# Implement Baum-Welch Learning

[ba10k](https://rosalind.info/problems/ba10k/)

## Baum-Welch Learning Problem

**Given:** A sequence of emitted symbols $x = x_1 \ldots x_n$ in an alphabet $\mathcal{A}$, generated by a $k$-state HMM with unknown transition and emission probabilities, together with initial *Transition* and *Emission* matrices and a number of iterations $I$.

**Return:** A matrix of transition probabilities *Transition* and a matrix of emission probabilities *Emission* that maximize $\Pr(x, \pi)$ over all possible transition and emission matrices and over all hidden paths $\pi$.

In [162]:
import numpy as np

In [163]:
def forward(x, T, E, initial_distribution):
    """Compute the forward table of an HMM for the observation sequence x.

    The table satisfies
        alpha[0, j] = initial_distribution[j] * E[j, x[0]]
        alpha[t, j] = (sum_i alpha[t-1, i] * T[i, j]) * E[j, x[t]]
    so alpha[t, j] is the probability of emitting x[0..t] and being in
    state j at step t.

    x is an integer array of symbol indices (columns of E). T is the
    transition matrix, E the emission matrix and initial_distribution the
    start-state weights. Returns an array of shape (len(x), T.shape[0]).
    """
    n_steps, n_states = x.shape[0], T.shape[0]
    alpha = np.zeros((n_steps, n_states))
    alpha[0, :] = initial_distribution * E[:, x[0]]
    for t in range(1, n_steps):
        prev = alpha[t - 1]
        # Probability mass flowing into each state j through column j of T.
        inflow = np.array([prev.dot(T[:, j]) for j in range(n_states)])
        alpha[t] = inflow * E[:, x[t]]
    return alpha

In [164]:
def backward(x, T, E):
    """Compute the backward table of an HMM for the observation sequence x.

    The table satisfies
        beta[n-1, j] = 1
        beta[t, j]   = sum_k T[j, k] * E[k, x[t+1]] * beta[t+1, k]
    so beta[t, j] is the probability of emitting x[t+1..n-1] given state j
    at step t.

    x is an integer array of symbol indices (columns of E). T is the
    transition matrix and E the emission matrix. Returns an array of shape
    (len(x), T.shape[0]).
    """
    n_steps, n_states = x.shape[0], T.shape[0]
    beta = np.zeros((n_steps, n_states))
    beta[n_steps - 1] = np.ones(n_states)
    for t in reversed(range(n_steps - 1)):
        # Next-step backward values weighted by the emission of x[t+1].
        weighted_next = beta[t + 1] * E[:, x[t + 1]]
        beta[t] = [weighted_next.dot(T[j, :]) for j in range(n_states)]
    return beta

In [165]:
def baum_welch(x, T, E, initial_distribution, I):
    """Re-estimate HMM parameters with I rounds of Baum-Welch (EM) training.

    Parameters
    ----------
    x : array-like of int
        Observed sequence; each entry is a symbol index (a column of E).
    T : array-like, shape (M, M)
        Starting transition matrix; T[i, j] weights the move i -> j.
    E : array-like, shape (M, K)
        Starting emission matrix; E[i, k] weights emitting symbol k in state i.
    initial_distribution : array-like, shape (M,)
        Start-state weights. Used as given on every iteration; it is not
        re-estimated.
    I : int
        Number of EM iterations.

    Returns
    -------
    (T, E) : tuple of ndarray
        The re-estimated transition and emission matrices. The caller's T and
        E are never modified. A state with zero expected occupancy produces a
        NaN row (0 / 0).
    """
    # Work on float copies so the M-step never writes into the caller's arrays.
    T = np.array(T, dtype=float)
    E = np.array(E, dtype=float)
    x = np.asarray(x)
    # one_hot[t, k] is 1.0 exactly when symbol k was emitted at step t.
    one_hot = (x[:, None] == np.arange(E.shape[1])[None, :]).astype(float)

    for _ in range(I):
        alpha = forward(x, T, E, initial_distribution)
        beta = backward(x, T, E)

        # E-step.
        # emit_beta[t, j] = E[j, x[t+1]] * beta[t+1, j]
        emit_beta = E[:, x[1:]].T * beta[1:]
        # xi[t, i, j] = Pr(state i at t, state j at t+1 | x). The per-step
        # total of the unnormalised products equals Pr(x).
        xi = alpha[:-1, :, None] * T[None, :, :] * emit_beta[:, None, :]
        xi /= xi.sum(axis=(1, 2), keepdims=True)
        # gamma[t, i] = Pr(state i at t | x). Summing xi over j covers steps
        # 0..n-2; the final step comes from summing the last xi over i.
        gamma = np.vstack((xi.sum(axis=2), xi[-1].sum(axis=0, keepdims=True)))

        # M-step: expected counts divided by expected state occupancy.
        # Transitions leave from steps 0..n-2 only, hence gamma[:-1].
        T = xi.sum(axis=0) / gamma[:-1].sum(axis=0)[:, None]
        E = (gamma.T @ one_hot) / gamma.sum(axis=0)[:, None]
    return T, E

In [166]:
def parse_input(lines):
    """Parse a Rosalind BA10K dataset given as a list of text lines.

    Expected layout, with a separator line between sections::

        I                      (line 0)
        --------
        x                      (line 2)
        --------
        alphabet symbols       (line 4)
        --------
        state names            (line 6)
        --------
        transition header      (line 8), then one row per state
        --------
        emission header, then one row per state

    Each matrix row starts with its state label, which is skipped.

    Returns
    -------
    tuple
        (observations, T, E, I, alphabet, states). Here observations is an
        array of alphabet indices for the symbols of x, and T and E are float
        matrices with one row per state.
    """
    I = int(lines[0].strip())  # number of iterations
    x = lines[2].strip().split()[0]  # sequence of emitted symbols
    alphabet = lines[4].strip().split()
    states = lines[6].strip().split()
    S = len(states)
    observations = np.array([alphabet.index(symbol) for symbol in x])
    # Transition rows follow the separator (line 7) and header (line 8).
    T = np.array([line.split()[1:] for line in lines[9:9 + S]], float)
    # Emission rows follow the next separator/header pair. Take exactly S rows
    # so trailing blank lines in the file cannot produce empty (ragged) rows.
    emission_start = 11 + S
    E = np.array(
        [line.split()[1:] for line in lines[emission_start:emission_start + S]],
        float,
    )
    return observations, T, E, I, alphabet, states

In [167]:
def print_output(matrix, rowNames, colNames):
    """Print matrix as a tab-separated table labelled by colNames and rowNames.

    A value whose str() is shorter than six characters is printed as-is.
    Longer values are printed with three decimal places.
    """
    def render(value):
        text = str(value)
        if len(text) >= 6:
            return '%.3f' % value
        return text

    print('\t' + '\t'.join(colNames))
    for label, values in zip(rowNames, matrix):
        cells = [label]
        cells.extend([render(v) for v in values])
        print('\t'.join(cells))

In [168]:
# Solve the dataset: run I Baum-Welch iterations and print both matrices.
with open("rosalind_ba10k.txt", encoding="utf-8") as f:
    lines = f.readlines()
x, T, E, I, alphabet, states = parse_input(lines)
# Uniform start distribution over all states, so it sums to 1 for any number
# of states (a fixed 0.5 per state is only a distribution when there are 2).
initial_distribution = np.full(len(states), 1.0 / len(states))
T, E = baum_welch(x, T, E, initial_distribution, I)
print_output(T, states, states)
print('--------')
print_output(E, states, alphabet)

	A	B	C	D
A	0.264	0.000	0.000	0.736
B	0.718	0.000	0.282	0.000
C	0.000	0.867	0.000	0.133
D	0.000	0.690	0.181	0.129
--------
	x	y	z
A	0.039	0.393	0.568
B	0.775	0.225	0.000
C	0.321	0.679	0.000
D	0.000	0.386	0.614
