In [273]:
import numpy as np

In [388]:
def sample_pi(v):
    return np.random.dirichlet(v)

def sample_gamma(u):
    K = len(u)
    gamma_matrix = np.zeros((K, K))
    
    # Here is normalized BY ROWS!!!
    for k in range(K):
        gamma_matrix[k] = np.random.dirichlet(u[k])
    return gamma_matrix

def sample_r(w):
    K = len(w[0])
    N = len(w) - 1
    r_matrix = np.zeros((N + 1, K))

    # Here is normalized BY COLUMNS!!!
    for k in range(K):
        r_matrix[:, k] = np.random.dirichlet(w[:, k])
    return r_matrix


def update_v(v, z):
    v[int(z[0])] += 1
    
    return v


def update_u(u, z):
    T = len(z)
    
    for t in range(T-1):
        u[int(z[t]), int(z[t+1])] += 1
        
    return u


def update_w(w, y, z):
    T = len(z)
    
    for t in range(T):
        w[int(y[t]), int(z[t])] += 1
        
    return w


def sample_z(y, z_prev, r, gamma):
    T = len(z_prev)
    K = len(gamma)
    z_new = np.zeros(T)
    eta = np.zeros((T, K))
    
    for k in range(K):
        eta[0, k] = r[y[0], k] * gamma[k, int(z_prev[1])]
    
    eta[0, :] /= np.sum(eta[0, :])
    z_new[0] = np.random.choice(K, p=eta[0, :])
    
    for t in range(1, T-1):
        for k in range(K):
            m = int(z_new[t-1])
            l = int(z_prev[t+1])
            eta[t, k] = r[y[t], k] * gamma[m, k] * gamma[k, l]
        
        eta[t, :] /= np.sum(eta[t, :])
        
        z_new[t] = np.random.choice(K, p=eta[t, :])
    
    for k in range(K):
        eta[T-1, k] = r[y[T-1], k] * gamma[int(z_new[T-2]), k]
    
    eta[T-1, :] /= np.sum(eta[T-1, :])
    z_new[T-1] = np.random.choice(K, p=eta[T-1, :])
    
    return z_new

In [436]:
def gibbs_sampling(y, z, v, u, w, num_iterations):
    for iter in range(num_iterations):
        if iter != 0 and iter%100 == 0:
            print(f"Iteration {iter}...")
      
        pi = sample_pi(v)
        gamma_matrix = sample_gamma(u)
        r_matrix = sample_r(w)
        
        z = sample_z(y, z, r_matrix, gamma_matrix)
        
        v = update_v(v, z)
        u = update_u(u, z)
        w = update_w(w, y, z)
        
    return pi, gamma_matrix, r_matrix, z

In [437]:
def initialize_parameters(N, K, seed=None):
    if seed is not None:
        np.random.seed(seed)

    # Probability vector of length K
    pi = np.random.dirichlet(alpha=np.ones(K))    

    assert(np.allclose(sum(pi), 1, atol=1e-6) == 1)

    # Transition matrix (K x K), where each column sums to 1, because a column
    # represents the jump probabilities from a fixed initial state.
    gamma = np.random.dirichlet(alpha=np.ones(K), size=K)

    assert(np.allclose(np.sum(gamma, axis = 1), np.ones(K), atol=1e-6) == 1)

    # r matrix (N+1 x K), where each column sums to 1, because a column
    # represents the firing probability of a neuron in a fixed state.
    r = np.random.dirichlet(alpha=np.ones(N + 1), size=K).T

    assert(np.allclose(np.sum(r, axis = 0), np.ones(K), atol=1e-6) == 1)

    return pi, gamma, r

def generate_hmm_data(T, N, K, pi, gamma, r):
    z = np.zeros(T, dtype=int)
    y = np.zeros(T, dtype=int)

    # Initialization t=0. I extract the value of K following the init prob
    z[0] = np.random.choice(K, p=pi)
    y[0] = np.random.choice(N + 1, p=r[:, z[0]])

    #data generation
    for t in range(1, T):
        # I take the column of gamma corresponding to z[t-1] and extract from this distribution
        z[t] = np.random.choice(K, p=gamma[z[t-1]])
        
        # I take the column of r corresponding to z[t] and extract from this prob distribution
        y[t] = np.random.choice(N + 1, p=r[:, z[t]])

    return y, z

In [438]:
# Example usage
N = 10  # Number of neurons
K = 3  # Number of hidden states
T = 100  # Number of time steps

v = np.ones(K)
u = np.ones((K, K)) 
w = np.ones((N + 1, K))

In [439]:
pi, gamma, r = initialize_parameters(N, K)
y, z = generate_hmm_data(T, N, K, pi, gamma, r)

In [440]:
print("pi:\n", pi)
print("gamma:\n", gamma)
print("r:\n", r)

pi:
 [0.05151607 0.49418201 0.45430192]
gamma:
 [[0.04216746 0.93119775 0.02663478]
 [0.96372457 0.03518239 0.00109304]
 [0.00322633 0.26720441 0.72956926]]
r:
 [[0.00933707 0.21545334 0.09833038]
 [0.22163235 0.04088649 0.00931921]
 [0.09221103 0.13672462 0.01017301]
 [0.06971308 0.07661738 0.06749894]
 [0.08535719 0.01666407 0.17775383]
 [0.1353473  0.03221208 0.00788518]
 [0.09041921 0.10880854 0.30237284]
 [0.03374332 0.06288201 0.14726194]
 [0.00492853 0.01469476 0.12566256]
 [0.04132324 0.02647357 0.04292721]
 [0.21598768 0.26858314 0.01081489]]


In [441]:
pi_post, gamma_post, r_post, z_post = gibbs_sampling(y, z, v, u, w, 10000)

Iteration 50...
Iteration 100...
Iteration 150...
Iteration 200...
Iteration 250...
Iteration 300...
Iteration 350...
Iteration 400...
Iteration 450...
Iteration 500...
Iteration 550...
Iteration 600...
Iteration 650...
Iteration 700...
Iteration 750...
Iteration 800...
Iteration 850...
Iteration 900...
Iteration 950...
Iteration 1000...
Iteration 1050...
Iteration 1100...
Iteration 1150...
Iteration 1200...
Iteration 1250...
Iteration 1300...
Iteration 1350...
Iteration 1400...
Iteration 1450...
Iteration 1500...
Iteration 1550...
Iteration 1600...
Iteration 1650...
Iteration 1700...
Iteration 1750...
Iteration 1800...
Iteration 1850...
Iteration 1900...
Iteration 1950...
Iteration 2000...
Iteration 2050...
Iteration 2100...
Iteration 2150...
Iteration 2200...
Iteration 2250...
Iteration 2300...
Iteration 2350...
Iteration 2400...
Iteration 2450...
Iteration 2500...
Iteration 2550...
Iteration 2600...
Iteration 2650...
Iteration 2700...
Iteration 2750...
Iteration 2800...
Iteration 28

In [442]:
print("pi_post:\n", pi_post)
print("gamma_post:\n", gamma_post)
print("r_post:\n", r_post)

pi_post:
 [0.99490791 0.00196189 0.0031302 ]
gamma_post:
 [[0.02064456 0.56212217 0.41723327]
 [0.44388893 0.50310483 0.05300624]
 [0.19960122 0.27203444 0.52836434]]
r_post:
 [[2.86672979e-01 8.70352575e-03 3.15937359e-02]
 [1.73678685e-02 2.11622459e-01 3.72031563e-02]
 [8.95247745e-02 1.47807042e-01 3.11952638e-02]
 [1.40617558e-01 4.90184871e-02 7.09753113e-02]
 [2.95092835e-04 1.97258291e-03 1.76209804e-01]
 [1.84429158e-01 8.55196124e-02 5.46528204e-03]
 [8.28596437e-03 1.67709465e-01 3.64625528e-01]
 [1.06361063e-03 3.54991940e-03 1.35564186e-01]
 [1.47126596e-01 8.88751640e-04 3.64379930e-04]
 [3.86005643e-03 3.17335541e-02 8.87444814e-02]
 [1.20756342e-01 2.91474600e-01 5.80588714e-02]]
