In [28]:
import numpy as np

In [102]:
def sample_pi(v):
    return np.random.dirichlet(v)

def sample_gamma(u):
    K = len(u)
    gamma_matrix = np.zeros((K, K))
    
    # Here is normalized BY ROWS!!!
    for k in range(K):
        gamma_matrix[k] = np.random.dirichlet(u[k])
    return gamma_matrix

def sample_r(w):
    K = len(w[0])
    N = len(w) - 1
    r_matrix = np.zeros((N + 1, K))

    # Here is normalized BY COLUMNS!!!
    for k in range(K):
        r_matrix[:, k] = np.random.dirichlet(w[:, k])
    return r_matrix

def sample_z(eta):
    T = len(eta)
    K = eta.shape[1]
    z = np.zeros(T)
    
    for t in range(T):
        dist = np.sum(eta[t, :, :, :], axis=(1, 2))  # Sum over m and l for each k
        dist /= np.sum(dist)
        z[t] = np.random.choice(K, p=dist)
    
    return z



def update_v(v, z):
    v[int(z[0])] += 1
    v /= np.sum(v)
    
    return v


def update_u(u, z):
    T = len(z)
    K = len(u)
    
    for t in range(T-1):
        u[int(z[t]), int(z[t+1])] += 1
                    
    for k in range(K):
        u[k] /= np.sum(u[k])
        
    return u


def update_w(w, y, z):
    T = len(z)
    K = len(w)
    
    for t in range(T):
        w[int(y[t]), int(z[t])] += 1
                    
    for k in range(K):
        w[k] /= np.sum(w[k])
        
    return w


def update_eta(y, r, gamma):

    K = len(r[0])
    T = len(y)
    eta = np.zeros((T, K, K, K))

    for t in range(T):
        for k in range(K):
            for l in range(K):
                for m in range(K):
                    eta[t, k, m, l] = r[y[t], k] * gamma[m, k] * gamma[k, l]
    
                #eta[t, k, :, l] /= np.sum(eta[t, k, :, l])
        
    return eta

In [112]:
def gibbs_sampling(y, z, v, u, w, num_iterations):    
    for iter in range(num_iterations):
        v = update_v(v, z)
        u = update_u(u, z)
        w = update_w(w, y, z)
        
        pi = sample_pi(v)
        gamma_matrix = sample_gamma(u)
        r_matrix = sample_r(w)
        
        eta = update_eta(y, r_matrix, gamma_matrix)
        z = sample_z(eta)
        
    return pi, gamma_matrix, r_matrix, z

In [138]:
# Example usage
N = 5  # Number of neurons
K = 3  # Number of hidden states
T = 5  # Number of time steps

v = np.ones(K)
u = np.ones((K, K)) 
w = np.ones((N + 1, K))

In [139]:
def initialize_parameters(N, K, seed=None):
    if seed is not None:
        np.random.seed(seed)

    # Probability vector of length K
    pi = np.random.dirichlet(alpha=np.ones(K))    

    assert(np.allclose(sum(pi), 1, atol=1e-6) == 1)

    # Transition matrix (K x K), where each column sums to 1, because a column
    # represents the jump probabilities from a fixed initial state.
    gamma = np.random.dirichlet(alpha=np.ones(K), size=K)

    assert(np.allclose(np.sum(gamma, axis = 1), np.ones(K), atol=1e-6) == 1)

    # r matrix (N+1 x K), where each column sums to 1, because a column
    # represents the firing probability of a neuron in a fixed state.
    r = np.random.dirichlet(alpha=np.ones(N + 1), size=K).T

    assert(np.allclose(np.sum(r, axis = 0), np.ones(K), atol=1e-6))

    return pi, gamma, r

def generate_hmm_data(T, N, K, pi, gamma, r):
    z = np.zeros(T, dtype=int)
    y = np.zeros(T, dtype=int)

    # Initialization t=0. I extract the value of K following the init prob
    z[0] = np.random.choice(K, p=pi)
    y[0] = np.random.choice(N + 1, p=r[:, z[0]])

    #data generation
    for t in range(1, T):
        # I take the column of gamma corresponding to z[t-1] and extract from this distribution
        z[t] = np.random.choice(K, p=gamma[z[t-1]])
        
        # I take the column of r corresponding to z[t] and extract from this prob distribution
        y[t] = np.random.choice(N + 1, p=r[:, z[t]])

    return y, z

In [140]:
pi, gamma, r = initialize_parameters(N, K)
y, z = generate_hmm_data(T, N, K, pi, gamma, r)

In [141]:
pi_post, gamma_post, r_post, z_post = gibbs_sampling(y, z, v, u, w, 1000)

ValueError: alpha <= 0

In [142]:
print("pi:", pi)
print("pi_post:", pi_post)

pi: [0.38302248 0.21703195 0.39994557]
pi_post: [0. 1. 0.]


In [143]:
print("gamma:", gamma)
print("gamma_post:", gamma_post)

gamma: [[0.11018196 0.39127889 0.49853915]
 [0.33376716 0.06772932 0.59850352]
 [0.33563811 0.50217952 0.16218237]]
gamma_post: [[1.00000000e+00 0.00000000e+00 0.00000000e+00]
 [5.34172237e-01 4.65827763e-01 0.00000000e+00]
 [8.39793654e-04 7.79579036e-01 2.19581170e-01]]


In [144]:
print("r:", r)
print("r_post:", r_post)

r: [[0.16746911 0.0376759  0.12274635]
 [0.35620538 0.06033349 0.03365968]
 [0.00808436 0.50230646 0.47461243]
 [0.00645759 0.03350373 0.00064324]
 [0.22157292 0.21493464 0.1598287 ]
 [0.24021063 0.15124578 0.2085096 ]]
r_post: [[8.01154364e-01 0.00000000e+00 0.00000000e+00]
 [4.64124184e-06 5.05675505e-02 6.25482548e-01]
 [6.56171142e-02 7.50449976e-01 2.84588408e-01]
 [0.00000000e+00 2.13820843e-02 0.00000000e+00]
 [5.80852680e-02 2.69454666e-05 6.79982649e-02]
 [7.51386126e-02 1.77573443e-01 2.19307794e-02]]
