In [24]:
import numpy as np

In [26]:
def forward(V, A, B, pi):
    """
    Forward pass of an HMM: joint prob. of the observations up to step t and
    being in state j at step t.

    alpha_0(j) = pi_j * B_j(O_0)
    alpha_t(j) = sigma_i_to_N(alpha_t-1(i) * A_ij) * B_j(O_t)
                    i to N is all possible states
                    t is the index of steps / time (0-based here)
                    j is a specific possible state
                    O_t is the observed visible state in time t

    V  : sequence of observation indices (used as column indices into B)
    A  : N x N transition matrix, A[i][j] weights moving from state i to state j
    B  : N x M emission matrix, B[j][k] is prob. of observation k in state j
    pi : length-N initial state distribution

    A, B and pi may be NumPy arrays or nested lists (converted with np.asarray).
    No scaling is applied, so very long sequences can underflow to 0.

    returns T * N matrix (T = len(V); an empty (0, N) matrix when V is empty)
    each row represents specific time
    each column represents specific state

    raises ValueError if pi does not have exactly one entry per state
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    pi = np.asarray(pi, dtype=float)
    n_steps, n_states = len(V), len(A)
    if pi.shape != (n_states,):
        raise ValueError(
            f"pi must have one entry per state ({n_states}), got shape {pi.shape}"
        )
    alpha = np.zeros((n_steps, n_states))
    if n_steps == 0:
        return alpha
    # Initial step: prior times the emission column of the first observation.
    alpha[0] = pi * B[:, V[0]]
    for t in range(1, n_steps):
        # (alpha_{t-1} @ A)[j] = sigma_i(alpha_{t-1}(i) * A_ij), then weight by B_j(O_t).
        alpha[t] = (alpha[t - 1] @ A) * B[:, V[t]]
    return alpha

def backward(V, A, B):
    """
    Backward pass of an HMM: prob. of the observations after step t given
    that the chain is in state i at step t.

    beta_T(i) = 1
    beta_t(i) = sigma_j_to_N(A_ij * B_j(O_t+1) * beta_t+1(j))   for t = T - 1 down to 1
                t is the index of steps / time (0-based in the code)
                T is the last step index
                i, j are states; j runs over all N states
                O_t+1 is the observation at step t+1 (V[t + 1])
                B_j(O_t+1) is prob. of emitting O_t+1 given state j
                beta_t+1(j) is backward prob. in time t+1 of state j

    returns T * N matrix
    each row represents specific time
    each column represents specific state
    """
    n_steps = len(V)
    n_states = len(A)
    beta = np.zeros((n_steps, n_states))
    # Terminal condition: nothing remains to be observed after the last step.
    for state in range(n_states):
        beta[-1][state] = 1
    # Walk backwards in time, folding in the next step's emission and beta.
    for t in reversed(range(n_steps - 1)):
        next_obs = V[t + 1]
        for i in range(n_states):
            beta[t][i] = sum(
                A[i][j] * B[j][next_obs] * beta[t + 1][j] for j in range(n_states)
            )
    return beta

def xi(V, A, B, pi, alpha, beta):
    """
    Posterior prob. of transitioning from state i at step t to state j at step t+1.

    xi_t(ij) = alpha_t(i) * A_ij * B_j(O_t+1) * beta_t+1(j)
               / sigma_i_to_N(sigma_j_to_N(alpha_t(i) * A_ij * B_j(O_t+1) * beta_t+1(j)))
                t is the index of steps / time (0-based in the code)
                i, j are states; both run over all N states
                O_t+1 is the observation at step t+1 (V[t + 1])
                B_j(O_t+1) is prob. of emitting O_t+1 given state j
                alpha_t(i) is forward prob. in time t of state i
                beta_t+1(j) is backward prob. in time t+1 of state j

    pi is accepted for interface symmetry but not used.

    returns (T - 1) * N * N matrix
    each N * N slice holds the transition probs i -> j for one time step
    """
    n_steps = len(V) - 1
    n_states = len(A)
    # Unnormalised joint terms for every (t, i, j).
    joint = np.zeros((n_steps, n_states, n_states))
    for t in range(n_steps):
        next_obs = V[t + 1]
        for i in range(n_states):
            for j in range(n_states):
                joint[t][i][j] = alpha[t][i] * A[i][j] * B[j][next_obs] * beta[t + 1][j]
    # Per-step normaliser: total mass over all (i, j) pairs at that step.
    normaliser = np.zeros(n_steps)
    for t in range(n_steps):
        normaliser[t] = sum(
            joint[t][i][j] for i in range(n_states) for j in range(n_states)
        )
    return joint / normaliser[:, None, None]

def gamma_with_xi(xi):
    """
    Derive per-step state posteriors from a xi tensor by marginalising out
    the destination state.

    gamma_t(i) = sigma_j_to_N(xi_t(ij))

    xi is expected to be (T - 1) * N * N, as returned by xi().
    returns (T - 1) * N matrix (no row for the final step, which has no xi)
    each row represents specific time
    each column represents specific state
    """
    n_steps = len(xi)
    n_states = len(xi[0])
    posterior = np.zeros((n_steps, n_states))
    for t, step in enumerate(xi):
        for i in range(n_states):
            posterior[t][i] = sum(step[i][j] for j in range(n_states))
    return posterior

def gamma(alpha, beta):
    """
    Posterior prob. of being in state i at step t, given the whole sequence.

    gamma_t(i) = alpha_t(i) * beta_t(i) / sigma_i_to_N(alpha_t(i) * beta_t(i))
                t is the index of steps / time
                i is a state; the sum runs over all N states
                alpha_t(i) is forward prob. in time t of state i
                beta_t(i) is backward prob. in time t of state i

    returns T * N matrix
    each row represents specific time
    each column represents specific state
    """
    n_steps = len(alpha)
    n_states = len(alpha[0])
    posterior = np.zeros((n_steps, n_states))
    for t in range(n_steps):
        weights = [alpha[t][i] * beta[t][i] for i in range(n_states)]
        evidence = sum(weights)
        for i, weight in enumerate(weights):
            posterior[t][i] = weight / evidence
    return posterior

def pi_dash(gamma):
    """
    Re-estimate the initial state distribution.

    pi_i = gamma_0(i): the posterior prob. of being in state i at the first
    time step. gamma is T * N (rows are time steps, columns are states), so
    this is the FIRST ROW of gamma. The first column would instead be state
    0's posterior at every time step.

    returns a length-N vector (a copy, so it does not alias gamma)
    each cell represents prob. of starting at given state
    """
    return np.array(gamma[0], dtype=float)

def a_dash(xi, gamma):
    """
    Re-estimate the transition matrix.

    a_ij = sigma_t_to_T-1(xi_t(ij)) / sigma_t_to_T-1(gamma_t(i))
            T is the last step index
            t is a step; only the first len(xi) = T - 1 rows of gamma are used,
            since the final step has no outgoing transition

    returns N * N matrix: updated transition matrix
    """
    n_states = len(xi[0])
    n_steps = len(xi)
    # Expected number of i -> j transitions.
    expected_moves = np.zeros((n_states, n_states))
    for i in range(n_states):
        for j in range(n_states):
            expected_moves[i][j] = sum(xi[t][i][j] for t in range(n_steps))
    # Expected number of visits to i that are followed by a transition.
    expected_visits = np.zeros(n_states)
    for i in range(n_states):
        expected_visits[i] = sum(gamma[t][i] for t in range(n_steps))
    return expected_moves / expected_visits[:, None]

def b_dash(V, gamma, obs_no):
    """
    Re-estimate the emission matrix.

    B_ik = sigma_t_to_T(1(O_t = k) * gamma_t(i)) / sigma_t_to_T(gamma_t(i))

    obs_no is the number of distinct observation symbols (columns of the result).
    returns N * obs_no matrix
    N is all possible states
    obs_no is all possible observations
    """
    n_states = len(gamma[0])
    n_steps = len(gamma)
    emission = np.zeros((n_states, obs_no))
    for i in range(n_states):
        # The denominator depends only on the state, so compute it once per row.
        occupancy = sum(gamma[t][i] for t in range(n_steps))
        for k in range(obs_no):
            hits = sum(gamma[t][i] for t in range(n_steps) if V[t] == k)
            emission[i][k] = hits / occupancy
    return emission
            
def baum_welch(V, A, B, pi, obs_no, n_iter=1, keep_pi=True):
    """
    Re-estimate HMM parameters with the Baum-Welch algorithm.

    Each iteration runs the forward and backward passes with the current
    parameters, computes the posteriors xi and gamma, and replaces A and B
    with their re-estimates. pi is also replaced when keep_pi is falsy.
    All re-estimates within one iteration are computed from the previous
    iteration's parameters.

    V       : sequence of observation indices (0 .. obs_no - 1)
    A       : N * N transition matrix
    B       : N * obs_no emission matrix
    pi      : length-N initial state distribution
    obs_no  : number of distinct observation symbols
    n_iter  : number of re-estimation iterations to run
    keep_pi : if truthy, pi is returned unchanged; otherwise it is re-estimated

    NOTE(review): forward/backward use unscaled probabilities, so long
    sequences may underflow.

    returns (A, B, pi) after n_iter iterations
    """
    for _ in range(n_iter):
        alpha = forward(V, A, B, pi)
        beta = backward(V, A, B)
        step_xi = xi(V, A, B, pi, alpha, beta)
        step_gamma = gamma(alpha, beta)
        A = a_dash(step_xi, step_gamma)
        B = b_dash(V, step_gamma, obs_no)
        if not keep_pi:
            pi = pi_dash(step_gamma)
    return A, B, pi

# Toy example: a 2-state, 2-symbol HMM, one Baum-Welch iteration on a
# length-3 observation sequence. pi is kept fixed (keep_pi defaults to True).
parthy_V_2 = np.array([0, 1, 0])
parthy_A_2 = np.array([[0.7, 0.3],
                       [0.4, 0.6]])
parthy_B_2 = np.array([[0.5, 0.5],
                       [0.6, 0.4]])
parthy_pi_2 = np.array([0.6, 0.4])
A, B, pi = baum_welch(
    parthy_V_2, parthy_A_2, parthy_B_2, parthy_pi_2, obs_no=2, n_iter=1
)
print("our A", A, sep="\n")
print("our B", B, sep="\n")

our A
[[0.6959847  0.3040153 ]
 [0.40104621 0.59895379]]
our B
[[0.64615787 0.35384213]
 [0.69409623 0.30590377]]
