In [58]:
import numpy as np

def initialize_parameters(N, K, seed=None):
    #Initialize parameters randomly

    if seed is not None:
        np.random.seed(seed)

    pi = np.random.dirichlet(alpha=np.ones(K))    

    check = np.allclose(sum(pi), 1, atol=1e-6)
    if check != 1:
        print("Inconsistent boundary conditions")


    # Transition matrix. it has K rows and K columns. In particular
    # the sum of probabilities from the same initial state should be 1
    gamma = np.random.dirichlet(alpha=np.ones(K), size=K).T

    sum_columns = np.sum(gamma, axis = 0)
    check = np.allclose(sum_columns, np.ones(K), atol=1e-6)
    if check != 1:
        print("Inconsistent boundary conditions")


    # I initialize the r matrix in this in order to satisfy the boundary condition (sum of 
    # columns must be 1 foreach column)
    r = np.random.dirichlet(alpha=np.ones(N + 1), size=K).T

    sum_columns = np.sum(r, axis = 0)
    check = np.allclose(sum_columns, np.ones(K), atol=1e-6)
    if check != 1:
        print("Inconsistent boundary conditions")

    return pi, gamma, r

def forward(y, pi, gamma, r, T, K):
    
    alpha = np.zeros((T, K))   # alpha has T elements of dimension K each

    alpha[0] = pi * r[y[0]]   # r[y[0]] is the first row of r: r_y1,k  

    for t in range(1, T):   
        for k in range(K):
            alpha[t, k] = r[y[t-1], k] * np.sum(alpha[t - 1] * gamma[k])      #gamma[k] is the k'th row of gamma
    return alpha

def backward(y, gamma, r, T, K):
    
    beta = np.zeros((T, K))
    beta[-1] = 1
    for t in range(T - 2, -1, -1):
        for k in range(K):
            beta[t, k] = r[y[t+1], k] * np.sum(beta[t + 1] * gamma[k])
    return beta

def e_step(y, pi, gamma, r, T, K):
    
    alpha = forward(y, pi, gamma, r, T, K)
    beta = backward(y, gamma, r, T, K)
    # print("alpha = ", alpha)
    # print("beta = ", beta)

    # it must have T-1 elements because it's the P(z_t = k and z_t+1 = l | ...)
    xi = np.zeros((T - 1, K, K))
    for t in range(T - 1):
        for k in range(K):
            for l in range(K):
                xi[t, k, l] = alpha[t, k] * gamma[k, l] * r[y[t + 1], l] * beta[t + 1, l]
        #renormalization
        xi[t] /= np.sum(xi[t])

    zeta = np.zeros((T, K))
    for t in range(0, T):
        for k in range(K):
            zeta[t,k] = alpha[t, k] * beta[t, k]

        zeta[t] /= np.sum(zeta[t])

    print("zeta = ", zeta)
    print("xi = ", xi)


    return zeta, xi


def m_step(y, zeta, xi, N, K):
    
    T = len(y)
    # update of pi
    pi = np.zeros(K)
    pi = zeta[0]    

    # update of gamma
    gamma = np.zeros((K,K))
    for k in range(K):
        denominator = np.sum(zeta[:, k])
        for l in range(K):
            numerator = np.sum(xi[:, k, l])
            gamma[k,l] = numerator/denominator
        


    # update of r 
    r = np.zeros((N + 1, K))

    for k in range(K):
        denominator = 0
        for i in range(N + 1):
            numerator = 0
            for t in range(T):
                if y[t]==i:
                    numerator += zeta[t, k]
                    denominator += zeta[t,k]
        
            r[i,k] = numerator
        
        r[:, k] /= denominator

    return pi, gamma, r

def em_algorithm(y, N, K, max_iter=100, tol=1e-2):
    
    T = len(y)
    pi, gamma, r = initialize_parameters(N, K)

        

    for iteration in range(max_iter):
        print(iteration)
        # E-step
        zeta, xi = e_step(y, pi, gamma, r, T, K)


        # M-step
        pi_updated, gamma_updated, r_updated = m_step(y, zeta, xi, N, K)

        print("pi_upd = ", pi_updated)
        print("gamma_upd = ", gamma_updated)
        print("r_upd = ", r_updated)

    

        # i compute the delta using the relative distance, and using the Frobenius norm
        # delta_pi = np.linalg.norm(pi_updated - pi, ord='fro') / (np.linalg.norm(pi, ord='fro') + 1e-10)
        delta_gamma = np.linalg.norm(gamma_updated - gamma, ord='fro')
        delta_r = np.linalg.norm(r_updated - r, ord='fro')
        

        # if delta_pi < tol and delta_gamma < tol and delta_r < tol:
        if delta_gamma < tol and delta_r < tol:
            print(f"Converged at iteration {iteration}")
            break

        pi = pi_updated
        gamma = gamma_updated
        r = r_updated

    return pi_updated, gamma_updated, r_updated

# Example usage
N = 5  # Number of neurons
K = 3  # Number of hidden states
T = 5  # Number of time steps
y = np.random.randint(0, N + 1, size=T)  # Simulated observations


pi, gamma, r = em_algorithm(y, N, K)

0
zeta =  [[0.11960949 0.78951753 0.09087298]
 [0.09809135 0.7989462  0.10296245]
 [0.05502143 0.92920107 0.01577749]
 [0.17049263 0.78872741 0.04077996]
 [0.22909274 0.67052059 0.10038667]]
xi =  [[[1.11029868e-02 3.40229087e-01 4.74773800e-03]
  [8.43311218e-02 3.30120844e-01 1.67610763e-02]
  [2.66872414e-04 1.70399518e-01 4.20407554e-02]]

 [[7.57417183e-03 6.79422502e-02 9.53074158e-03]
  [3.03805168e-01 3.48139567e-01 1.77686242e-01]
  [1.30966573e-04 2.44792607e-02 6.07116323e-02]]

 [[2.56068629e-03 4.88830056e-02 1.02764991e-02]
  [1.74773426e-01 4.26216619e-01 3.26010297e-01]
  [6.00857111e-06 2.39004343e-03 8.88341444e-03]]

 [[2.04182721e-02 1.09028898e-01 6.63470553e-02]
  [2.42306629e-01 1.65288024e-01 3.65961221e-01]
  [2.34086179e-05 2.60454264e-03 2.80219495e-02]]]
pi_upd =  [0.11960949 0.78951753 0.09087298]
gamma_upd =  [[0.0619599  0.84200031 0.13520899]
 [0.20247272 0.31928411 0.22289119]
 [0.00121802 0.56979765 0.39813538]]
r_upd =  [[0.         0.         0.     