In [1]:
from src.hmm import MultiCatEmissionHMM
import numpy as np

In [2]:
B_1 = np.array([
    [0.5, 0.2, 0.3],
    [0.0, 0.2, 0.8]
])

B_2 = np.array([
    [0.6, 0.4],
    [0.1, 0.9]
])

Bs = np.concatenate([B_1, B_2], axis=1)
Bs

array([[0.5, 0.2, 0.3, 0.6, 0.4],
       [0. , 0.2, 0.8, 0.1, 0.9]])

In [3]:
hmm = MultiCatEmissionHMM(
    init_A = np.array([
        [0.7, 0.3],
        [0.2, 0.8]
    ]),

    init_Bs = Bs,

    init_pi = np.array(
        [0.5, 0.5]
    ),

    num_emission_symbols = np.array(
        [3, 2]
    )
)

In [4]:
observations = np.array([[0, 1], [1, 0], [1, 1], [1, 0]]) # (D, not B), (C, B), (C, not B), (W, B)

In [5]:
hmm.predict(Ys=observations)

[array([1., 0.]),
 array([0.93333333, 0.06666667]),
 array([0.66666667, 0.33333333]),
 array([0.82222222, 0.17777778])]

In [6]:
hmm.viterbi(Ys=observations)

[0, 0, 0, 0]

In [7]:
from src.optim import ExpectationMaximization as EM

EM(hmm).optimize()

TypeError: Can't instantiate abstract class ExpectationMaximization with abstract method optimize

In [None]:
import numpy as np
from scipy.special import kl_div, softmax

gt = np.random.random(size=(20,)).astype(np.float64)
gt /= gt.sum()
gt.sum() 

1.0

In [None]:
params = np.random.random(size=(20,)).astype(np.float64)

def kl_divergence(p, q):
    return np.sum(np.where(p != 0, p * np.log(p / q), 0))

func = lambda x: kl_divergence(gt, softmax(x))


In [None]:
func(params)

0.31517604612656974

In [None]:
from scipy.optimize import minimize

optim_results = minimize(func, x0=params)

In [None]:
func(optim_results.x)

5.567756584516184e-09

In [None]:
params, softmax(optim_results.x), gt

(array([0.71070904, 0.32260829, 0.66607034, 0.69739536, 0.75224878,
        0.27679934, 0.67025646, 0.89531649, 0.59291052, 0.79779289,
        0.63210647, 0.96909607, 0.33489641, 0.20939145, 0.3725253 ,
        0.71812204, 0.64272915, 0.85419084, 0.11499961, 0.35119121]),
 array([0.01719333, 0.04145686, 0.06934084, 0.040187  , 0.00824475,
        0.05585652, 0.03086843, 0.03377262, 0.05731105, 0.13047017,
        0.10644186, 0.02736735, 0.0333719 , 0.05179446, 0.08887976,
        0.03114011, 0.04737967, 0.00028527, 0.11156971, 0.01706832]),
 array([0.01719738, 0.04145455, 0.06935031, 0.040189  , 0.00824458,
        0.05585038, 0.0308642 , 0.03377511, 0.05730762, 0.1304712 ,
        0.10643487, 0.02737304, 0.03337601, 0.05180269, 0.08887533,
        0.03113584, 0.04738003, 0.00028577, 0.11156947, 0.01706265]))