In [4]:
import numpy as np

def initialize_parameters(N, K, seed=None):
    #Initialize parameters randomly

    if seed is not None:
        np.random.seed(seed)

    pi = np.random.dirichlet(alpha=np.ones(K))    

    check = np.allclose(sum(pi), 1, atol=1e-6)
    if check != 1:
        print("Inconsistent boundary conditions")


    # Transition matrix. it has K rows and K columns. In particular
    # the sum of probabilities from the same initial state should be 1
    gamma = np.random.dirichlet(alpha=np.ones(K), size=K).T

    sum_columns = np.sum(gamma, axis = 0)
    check = np.allclose(sum_columns, np.ones(K), atol=1e-6)
    if check != 1:
        print("Inconsistent boundary conditions")


    # I initialize the r matrix in this in order to satisfy the boundary condition (sum of 
    # columns must be 1 foreach column)
    r = np.random.dirichlet(alpha=np.ones(N + 1), size=K).T

    sum_columns = np.sum(r, axis = 0)
    check = np.allclose(sum_columns, np.ones(K), atol=1e-6)
    if check != 1:
        print("Inconsistent boundary conditions")

    return pi, gamma, r

def forward(y, pi, gamma, r, T, K):
    
    alpha = np.zeros((T, K))   # alpha has T elements of dimension K each

    alpha[0] = pi * r[y[0]]   # r[y[0]] is the first row of r: r_y1,k  

    for t in range(1, T):   
        for k in range(K):
            alpha[t, k] = r[y[t-1], k] * np.sum(alpha[t - 1] * gamma[k])      #gamma[k] is the k'th row of gamma
    return alpha

def backward(y, gamma, r, T, K):
    
    beta = np.zeros((T, K))
    beta[-1] = 1
    for t in range(T - 2, -1, -1):
        for k in range(K):
            beta[t, k] = r[y[t+1], k] * np.sum(beta[t + 1] * gamma[k])
    return beta

def e_step(y, pi, gamma, r, T, K):
    
    alpha = forward(y, pi, gamma, r, T, K)
    beta = backward(y, gamma, r, T, K)

    # it must have T-1 elements because it's the P(z_t = k and z_t+1 = l | ...)
    xi = np.zeros((T - 1, K, K))
    for t in range(T - 1):
        for k in range(K):
            for l in range(K):
                xi[t, k, l] = alpha[t, k] * gamma[k, l] * r[y[t + 1], l] * beta[t + 1, l]
        #renormalization
        xi[t] /= np.sum(xi[t])

    zeta = np.zeros((T, K))
    for t in range(0, T):
        for k in range(K):
            zeta[t,k] = alpha[t, k] * beta[t, k]

        zeta[t] /= np.sum(zeta[t])


    return zeta, xi


def m_step(y, zeta, xi, N, K):
    
    T = len(y)
    # update of pi
    pi = np.zeros(K)
    pi = zeta[0]    

    # update of gamma
    gamma = np.zeros((K,K))
    for k in range(K):
        denominator = np.sum(zeta[:, k])
    for l in range(K):
        numerator = np.sum(xi[:, k, l])
        gamma[k,l] = numerator/denominator


    # update of r 
    r = np.zeros((N + 1, K))

    for k in range(K):
        denominator = 0
        for i in range(N + 1):
            numerator = 0
            for t in range(T):
                if y[t]==i:
                    numerator += zeta[t, k]
                    denominator += zeta[t,k]
        
            r[i,k] = numerator
        
        r[:, k] /= denominator

    return pi, gamma, r

def em_algorithm(y, N, K, max_iter=100, tol=1e-2):
    
    T = len(y)
    pi, gamma, r = initialize_parameters(N, K)

        

    for iteration in range(max_iter):
        print(iteration)
        # E-step
        zeta, xi = e_step(y, pi, gamma, r, T, K)
        

        # M-step
        pi_updated, gamma_updated, r_updated = m_step(y, zeta, xi, N, K)

    

        # i compute the delta using the relative distance, and using the Frobenius norm
        # delta_pi = np.linalg.norm(pi_updated - pi, ord='fro') / (np.linalg.norm(pi, ord='fro') + 1e-10)
        delta_gamma = np.linalg.norm(gamma_updated - gamma, ord='fro')
        delta_r = np.linalg.norm(r_updated - r, ord='fro')
        

        # if delta_pi < tol and delta_gamma < tol and delta_r < tol:
        if delta_gamma < tol and delta_r < tol:
            print(f"Converged at iteration {iteration}")
            break

        pi = pi_updated
        gamma = gamma_updated
        r = r_updated

    return pi_updated, gamma_updated, r_updated

# Example usage
N = 5  # Number of neurons
K = 3  # Number of hidden states
T = 100  # Number of time steps
y = np.random.randint(0, N + 1, size=T)  # Simulated observations


pi, gamma, r = em_algorithm(y, N, K)

0
[[[2.64666482e-01 6.90972452e-03 3.20876470e-01]
  [1.10808582e-01 3.52992769e-03 2.91943044e-01]
  [3.31258007e-04 9.40924899e-05 8.40418910e-04]]

 [[7.55722643e-01 2.39373863e-02 1.60640107e-02]
  [1.70183519e-01 6.57752366e-03 7.86130885e-03]
  [1.41484249e-02 4.87583696e-03 6.29346972e-04]]

 [[3.18982201e-01 1.05205485e-01 6.21207400e-03]
  [5.04754894e-02 2.03134223e-02 2.13617348e-03]
  [1.07293284e-01 3.85009336e-01 4.37253417e-03]]

 [[5.04037751e-01 3.25330588e-02 3.09974184e-01]
  [1.84053761e-02 1.44956320e-03 2.45976131e-02]
  [3.64658035e-02 2.56079262e-02 4.69287246e-02]]

 [[3.49122727e-01 1.33117220e-02 4.10365779e-01]
  [6.21366411e-02 2.89090963e-03 1.58718152e-01]
  [8.91476219e-04 3.69821937e-04 2.19277102e-03]]

 [[1.12143715e-01 2.86343326e-02 4.02145688e-01]
  [2.75833821e-02 8.59389940e-03 2.14952154e-01]
  [1.82542199e-02 5.07110018e-02 1.36981607e-01]]

 [[1.43855267e-01 2.38363165e-02 6.29940064e-01]
  [4.56989402e-03 9.23953533e-04 4.34875960e-02]
  [1.2

  r[:, k] /= denominator


[[[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]

 [[nan nan nan]
  [nan nan nan]
  [nan nan nan]]



In [7]:
y = np.random.randint(0, N + 1, size=T)
y

array([3, 4, 0, 1, 2, 3, 5, 1, 2, 2, 2, 5, 2, 3, 5, 1, 0, 3, 2, 5, 1, 3,
       3, 2, 1, 3, 4, 0, 4, 2, 0, 0, 2, 0, 2, 2, 5, 1, 5, 4, 0, 5, 3, 4,
       4, 1, 2, 2, 3, 4, 3, 1, 2, 3, 0, 5, 1, 0, 0, 2, 0, 4, 0, 5, 1, 1,
       3, 2, 3, 2, 5, 5, 3, 3, 1, 3, 1, 4, 2, 4, 5, 3, 1, 1, 3, 3, 2, 1,
       2, 3, 3, 4, 5, 5, 4, 0, 2, 0, 1, 1])