In [1]:
import os
import numpy as np
import pandas as pd
import scipy.stats as st
import numpy.random as rd
import matplotlib.pyplot as plt
from IPython.display import display
plt.style.use("ggplot")

%matplotlib inline

In [2]:
# Toy HMM: 2 hidden states, binary outputs, every probability uniform (0.5).
np.random.seed(0)             # fixed seed so the sampled sequence is reproducible
a = np.full((2, 2), 0.5)      # transition probabilities, a[i, j] = P(next=j | current=i)
b = np.full((2, 2), 0.5)      # emission probabilities, b[state, output]
gamma = np.full(2, 0.5)       # initial state probabilities
Tj = 10                       # length of the observation sequence
s = [0, 1]                    # output alphabet
y = np.random.choice(s, Tj)   # sampled observation sequence (uniform over s)
# Number of hidden states. Derived from the initial distribution (not from the
# output alphabet `s`): Baum_Welch uses it as the width of alpha/beta.
N1 = len(gamma)

In [3]:
# Display the sampled observation sequence (the cell's last expression is rendered as its output).
y

array([0, 1, 1, 0, 1, 1, 1, 1, 1, 1])

In [4]:
# Baum-Welch(E-step)
def Baum_Welch(y, gamma, a, b):
    """
    Run the E-step of Baum-Welch (forward-backward) for a discrete HMM.

    Parameters
    ----------
    y : sequence of int
        Observation sequence; each entry is used as a column index into ``b``.
    gamma : array_like, shape (N,)
        Initial state probabilities.
    a : array_like, shape (N, N)
        Transition probabilities, ``a[i, j]`` for moving from state i to state j.
    b : array_like, shape (N, M)
        Output (emission) probabilities, ``b[j, k]`` for emitting symbol k in state j.

    Returns
    -------
    alpha : ndarray, shape (T, N)
        Forward variables: ``alpha[t, j] = P(y[0..t], state_t = j)``.
    beta : ndarray, shape (T, N)
        Backward variables: ``beta[t, i] = P(y[t+1..T-1] | state_t = i)``,
        with ``beta[T-1] = 1``.
    tau : ndarray, shape (T, N, N)
        ``tau[t, i, j] = P(state_t = i, state_{t+1} = j | y)`` for ``t < T-1``.
        The last slice ``tau[T-1]`` is all zeros (no transition out of the
        final step).
    tau_ : ndarray, shape (T, N)
        ``tau_[t, i] = sum_j tau[t, i, j]`` (the state posterior for
        ``t < T-1``). The last row is zero because it sums ``tau[T-1]``.

    T is ``len(y)`` and N is ``len(gamma)``; both are derived from the
    arguments instead of notebook globals. No scaling is applied, so very
    long sequences can underflow.
    """
    y = np.asarray(y)
    gamma = np.asarray(gamma, dtype=float)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    T = len(y)
    N = len(gamma)

    alpha = np.zeros((T, N))
    beta = np.zeros((T, N))
    tau = np.zeros((T, N, N))
    if T == 0:
        # Nothing to propagate; return correctly shaped empty arrays.
        return alpha, beta, tau, np.zeros((T, N))

    # Forward algorithm (alpha). Each step weights by the emission of the
    # destination state j for the *current* observation y[t].
    alpha[0] = gamma * b[:, y[0]]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ a) * b[:, y[t]]

    # Backward algorithm (beta).
    beta[T - 1] = 1.0
    for t in range(T - 2, -1, -1):
        beta[t] = a @ (b[:, y[t + 1]] * beta[t + 1])

    # Pairwise posteriors (tau): normalise each step's joint weights so they sum to 1.
    for t in range(T - 1):
        xi = alpha[t][:, None] * a * (b[:, y[t + 1]] * beta[t + 1])[None, :]
        tau[t] = xi / xi.sum()

    # Marginalise over the next state.
    tau_ = tau.sum(axis=2)

    return alpha, beta, tau, tau_

In [5]:
# Run the E-step on the sampled sequence and dump each result array in turn.
alpha, beta, tau, tau_ = Baum_Welch(y, gamma, a, b)
print(alpha, beta, tau, tau_, sep="\n")

[[[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.25  0.25]
  [ 0.25  0.25]]

 [[ 0.    0.  ]
  [ 0.    0.  ]]]
