## The forward algorithm

This section explains the factorized transition used for the MSM model, and provides a simulation of what happens in the actual implementation.

For a Hidden Markov Model (HMM) let:  
  - $\Pi_t$ be the joint distribution of latent components $M_t$.  
  - $f(x_t | M_t)$ be the emission/data likelihood of the model.  
  - $A$ be the transition matrix.  

Then each step of the forward algorithm is of the form:  
$$
\Pi_{t+1}=\frac{f(x_t | M_t)\odot \Pi_t A}{f(x_t | M_t)* \Pi_t A}
$$  

where $\odot$ is the Hadamard product, and $*$ is the inner product. For notational convenience we also introduce the predictive distribution $Q_{t+1}= \Pi_t A$, so that the step equation becomes:

$$
\Pi_{t+1}=\frac{f(x_t | M_t)\odot Q_{t+1}}{f(x_t | M_t)* Q_{t+1}}
$$

Below follows an example.

In [1]:
import jax.numpy as jnp

pi = jnp.full(5, 0.2)
f = jnp.arange(1, 6)/10
A = jnp.full([5, 5], 0.2)
Q = jnp.dot(pi, A)
num = f * Q
pi_tp1 = num/jnp.sum(num)

print(pi_tp1)

[0.06666667 0.13333334 0.20000002 0.26666668 0.33333334]


When many steps are involved, repeated multiplication of probabilities tends to underflow, so it is convenient to work in log space instead. Recall that:

  - for $\vec{c} = \vec{a} \odot \vec{b} \implies \log \vec{c} = \log \vec{a} + \log \vec{b}$.  
  - for $d=\vec{a}*\vec{b} \implies \log d = \log \left( \sum_{i=1}^n \exp (\log a_i + \log b_i) \right)=\text{logsumexp}(\log \vec{a} + \log \vec{b})=\text{logsumexp}(\log \vec{c})$  
  - for $y=\vec{v}\mathbf M \implies \log y = \log \sum_{i=1}^n \exp (\log v_i+\log \mathbf M_{i,:})$
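
The three identities above can be checked numerically. The sketch below uses small made-up vectors `a`, `b` and a made-up row-stochastic matrix `M` (illustrative assumptions, not part of the model) and compares each base-space result with its log-space counterpart.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Illustrative inputs (assumed values, chosen only for the check)
a = jnp.array([0.2, 0.5, 0.3])
b = jnp.array([0.1, 0.6, 0.3])
M = jnp.array([[0.7, 0.2, 0.1],
               [0.3, 0.4, 0.3],
               [0.2, 0.2, 0.6]])
log_a, log_b, log_M = jnp.log(a), jnp.log(b), jnp.log(M)

# Hadamard product -> sum of logs
hadamard_ok = bool(jnp.allclose(jnp.log(a * b), log_a + log_b))

# Inner product -> logsumexp of the summed logs
inner_ok = bool(jnp.allclose(jnp.log(jnp.dot(a, b)), logsumexp(log_a + log_b)))

# Vector-matrix product -> logsumexp over the row index i
vecmat_ok = bool(jnp.allclose(jnp.log(a @ M),
                              logsumexp(log_a[:, None] + log_M, axis=0)))

print(hadamard_ok, inner_ok, vecmat_ok)
```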

Then the forward update step can be carried out in log space:

$$
\log \Pi_{t+1} = \log f(x_t | M_t) + \log \sum_{i=1}^n \exp (\log \Pi_{t,i}+\log A_{i,:}) - \text{logsumexp} \left( \log f(x_t | M_t) + \log \sum_{i=1}^n \exp (\log \Pi_{t,i}+\log A_{i,:}) \right)
$$

Or equivalently, using the predictive distribution notation: $\log Q_{t+1}=\log \sum_{i=1}^n \exp (\log \Pi_{t,i}+\log A_{i,:})$
  
$$
\log \Pi_{t+1} = \log f(x_t | M_t) + \log Q_{t+1} - \text{logsumexp} \left( \log f(x_t | M_t) + \log Q_{t+1} \right)
$$

In [3]:
from jax.scipy.special import logsumexp

log_f = jnp.log(f)
log_pi = jnp.log(pi)
log_A = jnp.log(A)

def joint_predictive(log_pi, log_A):
    # log_pi[:, None] adds a trailing axis so that log_pi broadcasts across the columns of log_A
    return logsumexp(log_pi[:, None] + log_A, axis=0)

# the function expects log-space inputs
log_Q = joint_predictive(log_pi, log_A)

# consistency check against the predictive computed in the base space
if not jnp.allclose(log_Q, jnp.log(Q)): print("Values are different")

num = log_f + log_Q
log_pi_tp1 = num - logsumexp(num)

# besides small approximation errors, the resulting distribution is the same as in the base space
print(jnp.exp(log_pi_tp1))

[0.06666666 0.13333333 0.2        0.26666665 0.3333333 ]


As the example shows, the two algorithms produce the same result up to floating-point error.
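
The underflow concern that motivates the log-space version can be made visible with a toy model. The sketch below is only an illustration under stated assumptions: the two-state matrix `A`, the emission likelihoods `f`, and the number of steps `T` are made up, and the recursion is run without the per-step normalization so that the effect shows quickly.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Toy model (assumed values): 2 states, the same small likelihoods at every step
A = jnp.array([[0.9, 0.1],
               [0.2, 0.8]])
f = jnp.array([1e-3, 2e-3])
log_A, log_f = jnp.log(A), jnp.log(f)
T = 200

alpha = jnp.array([0.5, 0.5])      # unnormalized forward message, base space
log_alpha = jnp.log(alpha)         # the same message, log space

for _ in range(T):
    alpha = f * (alpha @ A)
    log_alpha = log_f + logsumexp(log_alpha[:, None] + log_A, axis=0)

print(alpha)       # base-space message after T steps
print(log_alpha)   # log-space message after T steps
```

Each step shrinks the base-space message by roughly three orders of magnitude, so it falls below the smallest representable float long before step 200, while the log-space message only decreases linearly.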

## A simplified case: factorized latent transitions

The computations in the previous section have the drawback that they scale exponentially with the number of latent variables $M_i$. In particular, consider the homogeneous case in which there are $K$ latent variables, each taking one of the same $j$ possible values. Then at each forward step, $\Pi_t$ is a distribution vector of $j^K$ elements that is updated by multiplying it with a $(j^K \times j^K)$ transition matrix $A$, so the cost of each step is $\mathcal O(j^{2K})$.
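
To get a feel for these numbers, the short sketch below tabulates the size of the joint state space and the cost of one dense forward step. The values of `j` and `K` are illustrative assumptions.

```python
# Size of the joint state space and cost of one dense forward step
# for K latent variables with j values each (illustrative numbers).
j = 2
for K in (5, 10, 20):
    n_states = j ** K            # length of Pi_t
    dense_cost = n_states ** 2   # number of entries of the (j^K x j^K) matrix A
    print(f"K={K:2d}  states={n_states:>9,d}  dense step ~ {dense_cost:,d} multiply-adds")
```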

The complexity can be reduced to $\mathcal O(Kj^{K+1})$ for HMMs in which the latent variables evolve independently. This reduction is possible because $A$ can be represented as the Kronecker product of $K$ marginal transition matrices $A^{(k)}= P(M_{k,t} | M_{k, t-1})$:  
$$A = P(M_t | M_{t-1})=\bigotimes_{k=0}^{K-1} A^{(k)}$$
where $\bigotimes$ denotes the sequentially applied Kronecker products.  
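
As a small illustration of the Kronecker structure, the sketch below builds the joint matrix for two binary latent variables and checks one entry against the product of the marginal transitions. The matrices and the chosen indices are assumptions made for the example. The last lines compare the two per-step costs for one illustrative choice of $j$ and $K$.

```python
import jax.numpy as jnp

# Two marginal transition matrices (assumed values, rows sum to 1)
A0 = jnp.array([[0.7, 0.3], [0.4, 0.6]])
A1 = jnp.array([[0.1, 0.9], [0.6, 0.4]])
A = jnp.kron(A0, A1)                      # joint matrix, shape (4, 4)

# The joint state (i0, i1) maps to the flat index i0 * 2 + i1.
i0, i1, j0, j1 = 1, 0, 0, 1
entry_ok = bool(jnp.isclose(A[i0 * 2 + i1, j0 * 2 + j1],
                            A0[i0, j0] * A1[i1, j1]))

# The joint matrix is still row-stochastic.
rows_ok = bool(jnp.allclose(A.sum(axis=1), 1.0))

# Per-step cost: dense product vs. K mode products of j^(K+1) multiply-adds each
j, K = 2, 10
dense_cost = j ** (2 * K)
factorized_cost = K * j ** (K + 1)
print(entry_ok, rows_ok, dense_cost, factorized_cost)
```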

It is thus possible to apply the forward transitions independently along each latent variable, and aggregate them together only for the emission adjustment. This construction of the transition tensor corresponds to a Tucker tensor decomposition with a trivial core tensor composed only of 1s.

Let $\tilde \Pi_{t}=\text{reshape}(\Pi_t) \in \mathbb R^{j \times \cdots \times j}$ be the joint distribution at time $t$ reshaped as a tensor with $K$ dimensions, each of length $j$. Then the transition can be computed as:


\begin{align*}
\mathrm{apply\_transitions}\bigl(\tilde \Pi_{t},\{A^{(k)}\}\bigr)
&=
\underbrace{\tilde \Pi_{t}
   \times_{0}(A^{(0)})^{\!\top}
   \times_{1}(A^{(1)})^{\!\top}
   \cdots
   \times_{K-1}(A^{(K-1)})^{\!\top}}_{K\ \text{mode-products}}
\end{align*}

where, for a tensor $\mathcal X$ of order $K$ and a matrix $A \in \mathbb R^{J_k \times I_k}$, the mode-$k$ product $\times_k$ is defined by:

\begin{align*}
(\mathcal X\times_k A)_{i_0\cdots i_{k-1}\,j\,i_{k+1}\cdots i_{K-1}}
&=
\sum_{i_k=1}^{I_k}
A_{\,j\,i_k}\;
\mathcal X_{i_0\cdots i_{k-1}\,i_k\,i_{k+1}\cdots i_{K-1}}.
\end{align*}

where $I_k$ is the number of values that the $k$-th latent state can take.
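
The definition of the mode-$k$ product can be written as a single contraction. The sketch below is a minimal illustration, not the implementation used in the model: `mode_product` is a hypothetical helper name, and the tensor `X` and matrix `A` are filled with arbitrary numbers. The result is compared with the index sum written out explicitly through `jnp.einsum`.

```python
import jax.numpy as jnp

def mode_product(X, A, k):
    """Mode-k product: contract axis k of X with the second axis of A."""
    out = jnp.tensordot(A, X, axes=([1], [k]))   # the new axis of length J_k comes first
    return jnp.moveaxis(out, 0, k)               # move it back to position k

X = jnp.arange(24.0).reshape(2, 3, 4)            # order-3 tensor, I_1 = 3
A = jnp.arange(6.0).reshape(2, 3)                # J_1 = 2, I_1 = 3
Y = mode_product(X, A, 1)

# Reference: the definition as an explicit sum over the index i of axis 1
Y_ref = jnp.einsum('ji,aib->ajb', A, X)
print(Y.shape, bool(jnp.allclose(Y, Y_ref)))
```

In the transition formula the matrix passed to the mode product is the transpose $(A^{(k)})^{\top}$, so the sum runs over the previous state of the $k$-th latent variable.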

What follows are different implementations of this factorized algorithm. In tests, an implementation using only JAX constructs proved challenging because of tracing, so the implementation currently relies on Python loops. The first two versions work with normalized probabilities, while the other two use log-probabilities.

In [None]:
def make_apply_scaled(*As):
    def apply(p):
        # p: probability tensor of shape (S, S, ..., S), one axis per latent variable, sums to 1
        for axis, A in enumerate(As):
            # static moveaxis / tensordot / moveaxis
            p = jnp.moveaxis(p, axis, -1)
            p = jnp.tensordot(p, A, axes=([-1], [0]))
            p = jnp.moveaxis(p, -1, axis)

        return p

    return apply

# A0 = jnp.array([[0.7, 0.3], [0.4, 0.6]])
# A1 = jnp.array([[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.5, 0.1, 0.4]])
# A2 = jnp.array([[0.9, 0.1], [0.2, 0.8]])

A0 = jnp.array([[0.7, 0.3], [0.4, 0.6]])
A1 = jnp.array([[0.1, 0.9], [0.6, 0.4]])
A2 = jnp.array([[0.9, 0.1], [0.2, 0.8]])


# prior = jnp.array([
#     0.02, 0.03, 0.05, 0.10, 0.07, 0.08,
#     0.04, 0.06, 0.12, 0.13, 0.10, 0.20
# ])

prior = jnp.array([
    0.06, 0.09, 0.17, 0.23, 0.07, 0.08, 0.10, 0.20
])

A_tensor = (A0, A1, A2)
apply_fn = make_apply_scaled(*A_tensor)

dims = tuple(A.shape[0] for A in A_tensor)
#reshape(n_latent * [len(marg_prob)])


prior_tensor = prior.reshape(*dims)
pred_tensor = apply_fn(prior_tensor)

pred_fast = pred_tensor.reshape(-1)

log_pred_fast = jnp.log(pred_fast)


# Full-joint predictive and its log for reference
A_joint = jnp.kron(jnp.kron(A0, A1), A2)
pred_full = prior @ A_joint
log_pred_full = jnp.log(pred_full)


# print("Log full-joint predictive:", log_pred_full)
# print("Log fast-space predictive:", log_pred_fast)
print("Match in log-space:", jnp.allclose(log_pred_full, log_pred_fast))

Match in log-space: True


In [None]:
def apply_transitions(prior_tensor, transition_matrices):
    result = prior_tensor
    for axis, A in enumerate(transition_matrices):
        # Move axis to front
        result = jnp.moveaxis(result, axis, 0)
        # Apply transition matrix along this axis
        result = jnp.tensordot(A.T, result, axes=1)
        # Move axis back to original position
        result = jnp.moveaxis(result, 0, axis)
    return result

# Transition matrices for 3 latent variables with different number of states
A0 = jnp.array([[0.7, 0.3],
                [0.4, 0.6]])
A1 = jnp.array([[0.1, 0.6, 0.3],
                [0.3, 0.4, 0.3],
                [0.5, 0.1, 0.4]])
A2 = jnp.array([[0.9, 0.1],
                [0.2, 0.8]])

# Prior over joint states (2 x 3 x 2 = 12 states)
prior = jnp.array([
    0.02, 0.03, 0.05, 0.10, 0.07, 0.08,
    0.04, 0.06, 0.12, 0.13, 0.10, 0.20
])
prior_tensor = prior.reshape(2, 3, 2)

# Apply the factorized transitions defined above
predictive_tensor = apply_transitions(prior_tensor, [A0, A1, A2])
predictive_fast = predictive_tensor.reshape(-1)

# Full joint transition via kron product for reference
A_joint = jnp.kron(jnp.kron(A0, A1), A2)
predictive_full = prior @ A_joint

print("Match:", jnp.allclose(predictive_full, predictive_fast))


Match: True


In [None]:
def apply_transitions(prior_log_tensor, log_transition_matrices):
    """
    Applies each axis-wise transition in log-space using an optimized broadcasted log-sum-exp.

    prior_log_tensor: log-prob tensor of shape (d_0, d_1, ..., d_{K-1})
    log_transition_matrices: list of K matrices, where
      log_transition_matrices[k] has shape (d_k, d_k').
    """
    result = prior_log_tensor
    for axis, logA in enumerate(log_transition_matrices):
        # 1) Move the k-th latent axis to the last position
        r = jnp.moveaxis(result, axis, -1)  # shape (..., old_dim)

        # 2) Broadcast-add logA: r[..., :, None] has shape (..., old_dim, 1)
        #    logA[None, ...] has shape (1, old_dim, new_dim)
        #    result t has shape (..., old_dim, new_dim)
        t = r[..., :, None] + logA[None, :, :]

        # 3) log-sum-exp over the old state dimension (axis -2)
        s = logsumexp(t, axis=-2)  # shape (..., new_dim)

        # 4) Move the new state axis back to its original position
        result = jnp.moveaxis(s, -1, axis)

    return result


# Example usage to verify equivalence
if __name__ == '__main__':
    # Factorized transition mats
    A0 = jnp.array([[0.7, 0.3],
                    [0.4, 0.6]])
    A1 = jnp.array([[0.1, 0.6, 0.3],
                    [0.3, 0.4, 0.3],
                    [0.5, 0.1, 0.4]])
    A2 = jnp.array([[0.9, 0.1],
                    [0.2, 0.8]])
    logA0, logA1, logA2 = jnp.log(A0), jnp.log(A1), jnp.log(A2)

    # Prior over 2×3×2 joint grid
    prior = jnp.array([
        0.02, 0.03, 0.05, 0.10, 0.07, 0.08,
        0.04, 0.06, 0.12, 0.13, 0.10, 0.20
    ])
    prior_tensor = prior.reshape(2, 3, 2)
    log_prior_tensor = jnp.log(prior_tensor)

    # Fast log-space predictive
    log_pred_tensor = apply_transitions(log_prior_tensor, [logA0, logA1, logA2])
    log_pred_fast = log_pred_tensor.reshape(-1)

    # Full-joint predictive and its log for reference
    A_joint = jnp.kron(jnp.kron(A0, A1), A2)
    pred_full = prior @ A_joint
    log_pred_full = jnp.log(pred_full)

    # Compare results
    print("Match in log-space:", jnp.allclose(log_pred_full, log_pred_fast))


Match in log-space: True


In [None]:
import jax

# Factory to create a JIT-ed function with transitions embedded in closure
def make_apply_transitions(*log_transition_matrices):
    """
    Returns a JIT-ed function with log_transition_matrices baked into the closure.
    This avoids static_argnums hashing issues and is fully compatible with JAX.
    """
    @jax.jit
    def apply(prior_log_tensor):
        result = prior_log_tensor
        for axis, logA in enumerate(log_transition_matrices):
            r = jnp.moveaxis(result, axis, -1)
            t = r[..., :, None] + logA[None, :, :]
            s = logsumexp(t, axis=-2)
            result = jnp.moveaxis(s, -1, axis)
        return result
    return apply

# Example usage with full test:
if __name__ == '__main__':
    A0 = jnp.array([[0.7, 0.3], [0.4, 0.6]])
    A1 = jnp.array([[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.5, 0.1, 0.4]])
    A2 = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    logA0, logA1, logA2 = jnp.log(A0), jnp.log(A1), jnp.log(A2)

    # Create the apply function with static mats
    apply_transitions = make_apply_transitions(logA0, logA1, logA2)

    # Prior over 2×3×2 joint grid
    prior = jnp.array([
        0.02, 0.03, 0.05, 0.10, 0.07, 0.08,
        0.04, 0.06, 0.12, 0.13, 0.10, 0.20
    ])
    prior_tensor = prior.reshape(2, 3, 2)
    log_prior_tensor = jnp.log(prior_tensor)

    # Fast log-space predictive
    log_pred_tensor = apply_transitions(log_prior_tensor)
    log_pred_fast = log_pred_tensor.reshape(-1)

    # Full-joint predictive and its log for reference
    A_joint = jnp.kron(jnp.kron(A0, A1), A2)
    pred_full = prior @ A_joint
    log_pred_full = jnp.log(pred_full)

    print("Match in log-space:", jnp.allclose(log_pred_full, log_pred_fast))


Match in log-space: True


Trying to avoid Python loops entirely creates tracing problems. The issue is that JAX loop operators pass traced indices, which cannot be used where a function requires static arguments, such as the `axis` arguments of `jnp.moveaxis`. An example follows:

In [None]:
import jax.numpy as jnp
from jax.lax import fori_loop

A0 = jnp.array([[0.7, 0.3], [0.4, 0.6]])
A1 = jnp.array([[0.3, 0.7], [0.6, 0.4]])
A2 = jnp.array([[0.9, 0.1], [0.2, 0.8]])

# reshape outside of the looped section, so that the tensor shape is an input
dims = tuple(A.shape[0] for A in (A0, A1, A2))
prior = jnp.array([0.27, 0.08, 0.04, 0.06, 0.12, 0.13, 0.10, 0.20])
prior_reshaped = prior.reshape(*dims)

def make_predictive(prior_tensor:jnp.ndarray, *transition_matrices):
    stacked = jnp.stack(transition_matrices)
    def step(axis, prior_tensor):
        factor_transition_matrix = jnp.take(stacked, axis, axis=0)
        predictive_tensor = jnp.moveaxis(prior_tensor, axis, -1)
        predictive_tensor = jnp.tensordot(predictive_tensor, factor_transition_matrix, axes=([-1], [0]))
        return jnp.moveaxis(predictive_tensor, -1, axis)
    
    return fori_loop(0, len(transition_matrices), step, prior_tensor)

pr_fn = make_predictive(prior_reshaped, A0, A1, A2)

TypeError: cannot create weak reference to 'staticmethod' object
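
One possible way around the traced-axis restriction is sketched below. It is not the implementation used here, and it assumes that all latent variables share the same number of states, so that the matrices can be stacked and the tensor shape stays constant across iterations. Instead of moving a traced axis, each step always contracts the leading axis and appends the new axis at the end. After $K$ steps the axes have rotated back to their original order. The name `apply_transitions_rotating` is hypothetical.

```python
import jax.numpy as jnp
from jax.lax import fori_loop

def apply_transitions_rotating(prior_tensor, stacked):
    # stacked: (K, j, j) array holding one transition matrix per latent variable
    def step(k, p):
        # contract the leading axis with A^(k); tensordot appends the new axis last
        return jnp.tensordot(p, stacked[k], axes=([0], [0]))
    # only static axis arguments are used, so the traced index k is harmless
    return fori_loop(0, stacked.shape[0], step, prior_tensor)

A0 = jnp.array([[0.7, 0.3], [0.4, 0.6]])
A1 = jnp.array([[0.1, 0.9], [0.6, 0.4]])
A2 = jnp.array([[0.9, 0.1], [0.2, 0.8]])
prior = jnp.array([0.06, 0.09, 0.17, 0.23, 0.07, 0.08, 0.10, 0.20])

stacked = jnp.stack([A0, A1, A2])
pred_fast = apply_transitions_rotating(prior.reshape(2, 2, 2), stacked).reshape(-1)

# Full joint transition via Kronecker products for reference
pred_full = prior @ jnp.kron(jnp.kron(A0, A1), A2)
match = bool(jnp.allclose(pred_full, pred_fast))
print("Match:", match)
```

The restriction to equally sized factors comes from `fori_loop` requiring the carried tensor to keep the same shape at every iteration. Factors of different sizes would still need the Python loop shown earlier.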