In [1]:
import numpy as np
import warnings
from munch import Munch
import itertools

In [2]:
# Toy HMM: three hidden states, two emission symbols.
hidden_states = ['a','b','c']
emit_states = ['A','C']

# Uniform transition probability between every ordered pair of hidden states.
hmm_transition = {
    (src, dst): 1/3
    for src in hidden_states
    for dst in hidden_states
}

# Row i holds the emission probabilities of hidden_states[i];
# columns follow the order of emit_states.
emit_mat = np.array([
    [.8,.2],
    [.5,.5],
    [.2,.8]
])
hmm_emit = {
    (state, symbol): emit_mat[row, col].item()
    for row, state in enumerate(hidden_states)
    for col, symbol in enumerate(emit_states)
}

# Uniform initial distribution over the hidden states.
hmm_startprob = dict.fromkeys(hidden_states, 1/3)

hmm = Munch(
    states=hidden_states,
    emits=emit_states,
    tprob=hmm_transition,
    eprob=hmm_emit,
    initprob=hmm_startprob,
)

In [3]:
def mv_Viterbi(obs, hmm, cst, sat = True):
    '''
    Viterbi decoding over the augmented (hidden state, auxiliary state) space.

    The constraint enters as a binary "emission" of the auxiliary state at the
    last time step, so the same routine decodes both the case where the
    constraint must hold and the case where it must fail.

    obs: list of observed emissions.
    hmm: object exposing `states`, `initprob[k]`, `tprob[j,k]` and `eprob[k,o]`.
    cst: constraint object exposing `aux_size`, `init_fun(k, r)`,
         `update_fun(r, k, r_past)` and `cst_fun(r, sat)`.
    sat: required truth value of the constraint; forwarded to `cst.cst_fun`.

    Returns (opt_augstate, opt_state): the best sequence of (state, aux) pairs
    and the matching sequence of hidden states. Both are empty for empty `obs`.
    '''
    if len(obs) == 0:
        return ([], [])

    aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
    last = len(obs) - 1
    val = {}         # val[t,k,r]: best score of a path ending in (k,r) at time t
    ix_tracker = {}  # ix_tracker[t-1,k,r]: best predecessor (j,s) of (k,r) at time t

    for k in hmm.states:
        for r in aux_space:
            val[0,k,r] = cst.init_fun(k,r)*hmm.initprob[k]*hmm.eprob[k,obs[0]]
            if last == 0:
                # A single observation is also the last one, so the
                # constraint has to be applied here.
                val[0,k,r] *= cst.cst_fun(r,sat)

    # Forward: compute the value function and remember the best predecessor.
    for t in range(1,len(obs)):
        for k in hmm.states:
            for r in aux_space:
                max_val = -1  # below any attainable score (scores are >= 0)
                argmax = None
                for j in hmm.states:
                    for s in aux_space:
                        # r is the aux state at time t, so it is driven by the
                        # state k occupied at time t (not the predecessor j)
                        # and by the previous aux state s.
                        curr_val = val[t-1,j,s]*hmm.tprob[j,k]*cst.update_fun(r,k,s)
                        if curr_val > max_val:
                            max_val = curr_val
                            argmax = (j,s)
                val[t,k,r] = max_val*hmm.eprob[k,obs[t]]
                if t == last:
                    # the constraint is emitted by the final aux state
                    val[t,k,r] *= cst.cst_fun(r,sat)
                ix_tracker[t-1,k,r] = argmax

    # Backward: pick the best terminal pair, then follow the stored predecessors.
    max_val = -1
    best_state = None
    for k in hmm.states:
        for r in aux_space:
            curr_val = val[last,k,r]
            if curr_val > max_val:
                max_val = curr_val
                best_state = (k,r)

    reversed_path = [best_state]
    for t in range(last - 1, -1, -1):
        best_state = ix_tracker[t, best_state[0], best_state[1]]
        reversed_path.append(best_state)

    opt_augstate = reversed_path[::-1]
    opt_state = [pair[0] for pair in opt_augstate]

    return(opt_augstate, opt_state)

In [4]:
def update_fun(r,k , r_past):
    '''
    Indicator that the auxiliary pair r follows from state k and the previous pair r_past.

    The auxiliary state is a 2-tuple (m1, m2):
      m1 -- state 'a' has been visited: k is 'a', or m1 was already set in r_past.
      m2 -- stays set only while m2 was set in r_past and the current state is
            not a 'c' arriving before any 'a'.
    r: candidate new pair, k: current hidden state, r_past: previous pair.
    Returns 1 when r equals the pair implied by (k, r_past), otherwise 0.
    '''
    seen_a = (k == 'a') or r_past[0]
    c_not_early = seen_a or (not k == 'c')
    order_ok = c_not_early and r_past[1]

    implied = (seen_a, order_ok)
    return int(r == implied)

def init_fun(k, r):
    '''
    Indicator for the initial auxiliary pair r = (m1, m2) given the first state k.

    At the first step m1 is set exactly when k is 'a' and m2 is set exactly
    when k is not 'c'. Returns 1 if r matches that pair, otherwise 0.
    '''
    starts_with_a = k == 'a'
    starts_without_c = not k == 'c'
    implied = (starts_with_a, starts_without_c)

    return int(r == implied)
    
def cst_fun(r, sat):
    '''
    Constraint "emission" of the final auxiliary pair r = (m1, m2).

    Only the second component, m2, is inspected: returns 1 when r[1] equals
    the requested truth value sat, otherwise 0.
    '''
    order_flag = r[1]
    return int(order_flag == sat)

In [18]:
# Precedence constraint: a two-component auxiliary state driven by
# update_fun / init_fun, with cst_fun as the final constraint emission.
prec_cst = Munch(
    name='a occurs before c',
    aux_size=2,
    update_fun=update_fun,
    init_fun=init_fun,
    cst_fun=cst_fun,
)
cst = prec_cst

In [5]:
def update_fun2(r,k , r_past):
    '''
    Indicator that the 1-tuple r follows from state k and the previous tuple r_past.

    The single auxiliary bit m1 records whether state 'b' has occurred: it is
    set when k is 'b' or when it was already set in r_past.
    Returns 1 when r == (m1,), otherwise 0.
    '''
    seen_b = (k == 'b') or r_past[0]
    implied = (seen_b,)

    return int(r == implied)

def init_fun2(k, r):
    '''
    Indicator for the initial auxiliary 1-tuple r = (m1,) given the first state k.

    m1 starts out set exactly when k is 'b'. Returns 1 if r matches, otherwise 0.
    '''
    starts_with_b = k == 'b'
    implied = (starts_with_b,)

    return int(r == implied)
    
def cst_fun2(r, sat):
    '''
    Constraint "emission" of the final auxiliary 1-tuple r = (m1,).

    Returns 1 when the single auxiliary bit r[0] equals the requested truth
    value sat, otherwise 0.
    '''
    occurred_flag = r[0]

    return int(occurred_flag == sat)

In [7]:
# Occurrence constraint: a single auxiliary bit driven by
# update_fun2 / init_fun2, with cst_fun2 as the final constraint emission.
occur_cst = Munch(
    name='b must occur',
    aux_size=1,
    update_fun=update_fun2,
    init_fun=init_fun2,
    cst_fun=cst_fun2,
)
cst = occur_cst

In [117]:
from mv_Viterbi import mv_Viterbi

In [118]:
# Five observed emission symbols, one character per time step.
obs = list('AACAA')

In [119]:
# Decode under the occurrence constraint, required to hold (sat=True).
opt_aug, opt_state = mv_Viterbi(
    obs,
    hmm,
    occur_cst,
    sat=True,
)

In [120]:
# Display the decoded hidden-state sequence returned by mv_Viterbi.
opt_state

['b', 'a', 'c', 'a', 'a']

In [121]:
def create_updatefun(zip_list):
    '''
    Build the update indicator of an aggregated constraint.

    zip_list: list of (cst, ix) pairs, where ix = (start, stop) is the slice of
    the joint auxiliary tuple owned by cst.
    The returned function multiplies the members' update_fun values, each
    evaluated on its own slice of r and r_past.
    '''
    def update_fun_agg(r,k,r_past):
        product = 1
        for member, ix in zip_list:
            window = slice(ix[0], ix[1])
            new_part = tuple(r[window])
            old_part = tuple(r_past[window])
            product = product * member.update_fun(new_part, k, old_part)
        return product
    return update_fun_agg

def create_initfun(zip_list):
    '''
    Build the initial indicator of an aggregated constraint.

    zip_list: list of (cst, ix) pairs, where ix = (start, stop) is the slice of
    the joint auxiliary tuple owned by cst.
    The returned function multiplies the members' init_fun values, each
    evaluated on its own slice of r.
    '''
    def init_fun_agg(k,r):
        product = 1
        for member, ix in zip_list:
            own_part = tuple(r[slice(ix[0], ix[1])])
            product = product * member.init_fun(k, own_part)
        return product
    return init_fun_agg

def create_cstfun(zip_list):
    '''
    Build the final constraint emission of an aggregated constraint.

    zip_list: list of (cst, ix) pairs, where ix = (start, stop) is the slice of
    the joint auxiliary tuple owned by cst.
    The returned function takes sat as a sequence with one entry per member
    (same order as zip_list) and multiplies the members' cst_fun values.
    '''
    def cst_fun_agg(r,sat):
        product = 1
        for position, (member, ix) in enumerate(zip_list):
            own_part = tuple(r[slice(ix[0], ix[1])])
            product = product * member.cst_fun(own_part, sat[position])
        return product
    return cst_fun_agg

In [165]:
def cst_aggregate(cst_list):
    '''
    Combine several constraints into one constraint over the concatenated auxiliary state.

    Each member of cst_list is given the contiguous slice (start, stop) of the
    joint auxiliary tuple, sized by its aux_size, in list order.
    Returns a single Munch whose name is the list of member names, whose
    aux_size is the total size, and whose update/init/cst functions are the
    aggregated versions built by create_updatefun / create_initfun / create_cstfun.
    '''
    spans = []   # (start, stop) slice of the joint aux tuple for each member
    names = []
    offset = 0
    for member in cst_list:
        stop = offset + member.aux_size
        spans.append((offset, stop))
        names.append(member.name)
        offset = stop

    paired = list(zip(cst_list, spans))

    return Munch(
        name=names,
        aux_size=offset,
        update_fun=create_updatefun(paired),
        init_fun=create_initfun(paired),
        cst_fun=create_cstfun(paired),
    )

In [153]:
# cst_aggregate returns one combined constraint object (not a pair), so it is
# bound to a single name; unpacking it into two names raises ValueError.
cst_combined = cst_aggregate([prec_cst, occur_cst])

In [154]:
# Five observed emission symbols, one character per time step.
obs = list('AACAA')

In [163]:
# Decode under the combined constraint; sat carries one required truth value
# per member constraint, in the order they were aggregated.
opt_aug, opt_state = mv_Viterbi(
    obs,
    hmm,
    cst_combined,
    sat=[False, False],
)

In [164]:
# Display the decoded hidden-state sequence returned by mv_Viterbi.
opt_state

['c', 'a', 'c', 'a', 'a']

In [133]:
def mv_Viterbi(obs, hmm, cst, sat = True):
    '''
    Viterbi decoding over the augmented (hidden state, auxiliary state) space.

    The constraint enters as a binary "emission" of the auxiliary state at the
    last time step, so the same routine decodes both the case where the
    constraint must hold and the case where it must fail.

    obs: list of observed emissions.
    hmm: object exposing `states`, `initprob[k]`, `tprob[j,k]` and `eprob[k,o]`.
    cst: constraint object exposing `aux_size`, `init_fun(k, r)`,
         `update_fun(r, k, r_past)` and `cst_fun(r, sat)`.
    sat: required truth value of the constraint; forwarded to `cst.cst_fun`.

    Returns (opt_augstate, opt_state): the best sequence of (state, aux) pairs
    and the matching sequence of hidden states. Both are empty for empty `obs`.
    '''
    if len(obs) == 0:
        return ([], [])

    aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
    last = len(obs) - 1
    val = {}         # val[t,k,r]: best score of a path ending in (k,r) at time t
    ix_tracker = {}  # ix_tracker[t-1,k,r]: best predecessor (j,s) of (k,r) at time t

    for k in hmm.states:
        for r in aux_space:
            val[0,k,r] = cst.init_fun(k,r)*hmm.initprob[k]*hmm.eprob[k,obs[0]]
            if last == 0:
                # A single observation is also the last one, so the
                # constraint has to be applied here.
                val[0,k,r] *= cst.cst_fun(r,sat)

    # Forward: compute the value function and remember the best predecessor.
    for t in range(1,len(obs)):
        for k in hmm.states:
            for r in aux_space:
                max_val = -1  # below any attainable score (scores are >= 0)
                argmax = None
                for j in hmm.states:
                    for s in aux_space:
                        # r is the aux state at time t, so it is driven by the
                        # state k occupied at time t (not the predecessor j)
                        # and by the previous aux state s.
                        curr_val = val[t-1,j,s]*hmm.tprob[j,k]*cst.update_fun(r,k,s)
                        if curr_val > max_val:
                            max_val = curr_val
                            argmax = (j,s)
                val[t,k,r] = max_val*hmm.eprob[k,obs[t]]
                if t == last:
                    # the constraint is emitted by the final aux state
                    val[t,k,r] *= cst.cst_fun(r,sat)
                ix_tracker[t-1,k,r] = argmax

    # Backward: pick the best terminal pair, then follow the stored predecessors.
    max_val = -1
    best_state = None
    for k in hmm.states:
        for r in aux_space:
            curr_val = val[last,k,r]
            if curr_val > max_val:
                max_val = curr_val
                best_state = (k,r)

    reversed_path = [best_state]
    for t in range(last - 1, -1, -1):
        best_state = ix_tracker[t, best_state[0], best_state[1]]
        reversed_path.append(best_state)

    opt_augstate = reversed_path[::-1]
    opt_state = [pair[0] for pair in opt_augstate]

    return(opt_augstate, opt_state)

['a occurs before c', 'b must occur']

In [5]:
# Bundle the HMM pieces (state list, emission symbols and the three
# probability tables) into a single attribute-style container.
hmm = Munch(
    states=hidden_states,
    emits=emit_states,
    tprob=hmm_transition,
    eprob=hmm_emit,
    initprob=hmm_startprob,
)

In [6]:
# Precedence constraint with a two-component auxiliary state.
cst = Munch(
    name='a occurs before c',
    aux_size=2,
    update_fun=update_fun,
    init_fun=init_fun,
    cst_fun=cst_fun,
)

In [53]:
# Seven observed emission symbols, one character per time step.
obs = list('AACAACC')

In [54]:
# `state_index` is not assigned in any visible cell of this notebook, so a
# fresh kernel would raise NameError here. Build the state -> position map
# explicitly before displaying it.
state_index = {state: position for position, state in enumerate(hmm.states)}
state_index

{'a': 0, 'b': 1, 'c': 2}

In [63]:
# Enumerate the auxiliary space and record the problem dimensions.
aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
T, K, M = len(obs), len(hmm.states), len(aux_space)
sat = True

# Position of every hidden state / auxiliary tuple along its array axis.
state_ix = {state: pos for pos, state in enumerate(hmm.states)}
aux_ix = {aux: pos for pos, aux in enumerate(aux_space)}

# HMM parameters as dense arrays: tmat[from, to] and the initial distribution.
tmat = np.zeros((K,K))
initprob_vec = np.zeros(K)
for src, row in state_ix.items():
    initprob_vec[row] = hmm.initprob[src]
    for dst, col in state_ix.items():
        tmat[row, col] = hmm.tprob[src, dst]

# Constraint indicators:
#   ind[new_aux, state, old_aux], init_ind[aux, state], final_ind[aux].
ind = np.zeros((M,K,M))
init_ind = np.zeros((M,K))
final_ind = np.zeros(M)
for new_aux, new_pos in aux_ix.items():
    final_ind[new_pos] = cst.cst_fun(new_aux, sat)
    for state, state_pos in state_ix.items():
        init_ind[new_pos, state_pos] = cst.init_fun(state, new_aux)
        for old_aux, old_pos in aux_ix.items():
            ind[new_pos, state_pos, old_pos] = cst.update_fun(new_aux, state, old_aux)

# emit_weights[t, k]: probability that state k emits the symbol observed at t.
emit_weights = np.array(
    [[hmm.eprob[state, symbol] for state in hmm.states] for symbol in obs],
    dtype=float,
).reshape(T, K)

In [64]:
# Forward (alpha) and backward (beta) messages over (time, state, aux).
alpha = np.empty((T,K,M))
beta = np.empty(alpha.shape)

# alpha[0] must be weighted by the emission of the FIRST observation, obs[0]
# (indexing obs[1] here would use the second symbol for the initial step).
curr_emits = np.array([hmm.eprob[k,obs[0]] for k in hmm.states])
alpha[0] = np.einsum('i,i,ri -> ir',curr_emits, initprob_vec,init_ind)
beta[-1] = 1

In [65]:
#Compute the forward pass
for t in range(1,T):
    if t == (T-1):
        alpha[t] = np.einsum('i,ji,ris,js,r->ir', emit_weights[t], tmat, ind, alpha[t-1], final_ind)
    else:
        alpha[t] = np.einsum('i,ji,ris,js->ir', emit_weights[t], tmat, ind, alpha[t-1])

#Compute the backward pass
for t in range(1,T):
    if t == 1:
        beta[T-1-t] = np.einsum('js,j,ij,sjr,s->ir', beta[T-t],emit_weights[T-t],tmat,ind, final_ind)
    else:
        beta[T-1-t] = np.einsum('js,j,ij,sjr->ir', beta[T-t],emit_weights[T-t],tmat,ind)

In [66]:
# Sanity check: the alpha-beta inner product is printed for every time step
# so the values can be compared against each other.
for step in range(T):
    joint_mass = np.einsum('ir,ir->', alpha[step], beta[step])
    print(joint_mass)

0.00588634545038866
0.00588634545038866
0.00588634545038866
0.005886345450388658
0.005886345450388659
0.005886345450388658
0.005886345450388658


In [71]:
# Shape check of the emission-weight slice for steps 1..T-1.
emit_weights[1:].shape

(6, 3)

In [69]:
# Shape check of the forward messages for steps 0..T-2.
alpha[:(T-1)].shape

(6, 3, 4)

In [73]:
# Joint probability of the observations and the constraint outcome.
prob_data  = np.einsum('ir,ir->',alpha[0],beta[0]) #doesn't matter which time index. all give same

# First moment: posterior state marginals per time step.
gamma = 1/prob_data*np.einsum('tir,tir->ti',alpha,beta)

# Second moment: posterior transition marginals. beta[-1] is all ones and does
# not carry the constraint indicator, so the message used for the LAST
# transition must be weighted by final_ind explicitly; otherwise that slice
# is not conditioned on the constraint.
beta_next = beta[1:].copy()
if T > 1:
    beta_next[-1] *= final_ind
xi = 1/prob_data*np.einsum('tjr,tk,jk,skr,tks->tjk',alpha[:(T-1)],emit_weights[1:],tmat,ind,beta_next)

In [77]:
# Shape check of the transition moments summed over time.
xi.sum(axis = 0).shape

(3, 3)

In [83]:
# M-step: the initial distribution is the normalised posterior at time 0.
first_marginal = gamma[0]
pi_opt = first_marginal/first_marginal.sum()

# Each row of the time-summed transition moments is divided by the total
# posterior mass leaving that state.
outgoing_mass = xi.sum(axis = (0,2))
tmat_opt = xi.sum(axis = 0)/outgoing_mass[:,np.newaxis]

In [89]:
def compute_emitweights(obs,hmm):
    '''
    Tabulate the emission weight of every hidden state at every time step.

    obs: sequence of observed emissions (length T).
    hmm: object exposing `states` (length K) and `eprob[state, emission]`.

    Returns a float array of shape (T, K) whose entry [t, k] is
    hmm.eprob[hmm.states[k], obs[t]].
    '''
    n_steps = len(obs)
    n_states = len(hmm.states)
    rows = [[hmm.eprob[state, symbol] for state in hmm.states] for symbol in obs]
    # reshape keeps the (T, K) shape even when obs is empty
    return np.array(rows, dtype=float).reshape(n_steps, n_states)

In [94]:
def arrayConvert(obs, hmm, cst, sat):
    '''
    Turn the dict-based HMM and the function-based constraint into numpy arrays.

    The constraint indicators are treated as fixed; only the HMM arrays are
    meant to be re-estimated downstream.

    obs: observed emissions (only its length is taken here).
    hmm: object exposing `states`, `initprob[k]` and `tprob[j,k]`.
    cst: object exposing `aux_size`, `init_fun`, `update_fun` and `cst_fun`.
    sat: truth value forwarded to `cst.cst_fun`.

    Returns (hmm_params, cst_params) where
      hmm_params = [tmat (K,K), init_prob (K,)]
      cst_params = [init_ind (M,K), final_ind (M,), ind (M,K,M)]
    with K hidden states and M = 2**cst.aux_size auxiliary tuples.
    '''
    aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
    n_obs = len(obs)  # not used further below
    n_states = len(hmm.states)
    n_aux = len(aux_space)

    # Position of every hidden state / auxiliary tuple along its array axis.
    state_pos = {state: pos for pos, state in enumerate(hmm.states)}
    aux_pos = {aux: pos for pos, aux in enumerate(aux_space)}

    # HMM parameters.
    tmat = np.zeros((n_states, n_states))
    init_prob = np.zeros(n_states)
    for src, row in state_pos.items():
        init_prob[row] = hmm.initprob[src]
        for dst, col in state_pos.items():
            tmat[row, col] = hmm.tprob[src, dst]

    # Constraint indicators, indexed [new_aux, state, old_aux] for the update.
    ind = np.zeros((n_aux, n_states, n_aux))
    init_ind = np.zeros((n_aux, n_states))
    final_ind = np.zeros(n_aux)
    for new_aux, a in aux_pos.items():
        final_ind[a] = cst.cst_fun(new_aux, sat)
        for state, k in state_pos.items():
            init_ind[a, k] = cst.init_fun(state, new_aux)
            for old_aux, b in aux_pos.items():
                ind[a, k, b] = cst.update_fun(new_aux, state, old_aux)

    return [tmat, init_prob], [init_ind, final_ind, ind]

In [85]:
def mv_BaumWelch(hmm_params, emit_weights, cst_params):
    '''
    One Baum-Welch step on the augmented (state, aux) chain: runs the
    forward-backward recursions, forms the posterior moments and returns the
    re-estimated transition matrix and initial distribution.
    Emission re-estimation is not handled here.

    IN
    hmm_params (list) = [tmat, init_prob]
        tmat: (K,K) with tmat[from, to]; init_prob: (K,)
    emit_weights: array of shape (T,K), emission weight of each state at each
        time step (must be recomputed by the caller if emissions change).
    cst_params (list) = [init_ind, final_ind, ind]
        init_ind: (M,K) initial aux indicator
        final_ind: (M,) constraint indicator on the final aux state
        ind: (M,K,M) update indicator indexed [new_aux, state, old_aux]

    OUT
    ([tmat_opt, pi_opt], prob_data) where prob_data is the joint probability
    of the observations and the constraint outcome encoded by final_ind.
    A row of tmat_opt whose state carries no posterior transition mass keeps
    its current values from tmat instead of becoming NaN.
    If prob_data is 0 (the constraint outcome is impossible under the
    current parameters) the posterior moments are not finite.
    '''
    tmat, init_prob = hmm_params
    init_ind, final_ind, ind = cst_params
    T, K = emit_weights.shape
    M = init_ind.shape[0]

    alpha = np.empty((T,K,M))
    beta = np.empty(alpha.shape)

    # Forward pass. The constraint indicator belongs to the last time step
    # only; applying it after the loop also covers the T == 1 case.
    alpha[0] = np.einsum('i,i,ri->ir', emit_weights[0], init_prob, init_ind)
    for t in range(1,T):
        alpha[t] = np.einsum('i,ji,ris,js->ir', emit_weights[t], tmat, ind, alpha[t-1])
    alpha[-1] *= final_ind

    # Backward pass. beta[-1] stays at 1 (the indicator already lives in
    # alpha[-1]); `terminal` is the message that carries final_ind backwards.
    beta[-1] = 1
    terminal = beta[-1]*final_ind
    for t in range(T-2, -1, -1):
        incoming = terminal if t == T-2 else beta[t+1]
        beta[t] = np.einsum('js,j,ij,sjr->ir', incoming, emit_weights[t+1], tmat, ind)

    #Compute P(Y,C=c), probability of observing emissions AND the constraint in the specified truth configuration
    prob_data = np.einsum('ir,ir->',alpha[0],beta[0]) #same value at every time index

    # First moment: posterior state marginals.
    gamma = np.einsum('tir,tir->ti',alpha,beta)/prob_data

    # Second moment: posterior transition marginals. The message for the last
    # transition must include final_ind, otherwise that slice is not
    # conditioned on the constraint.
    if T > 1:
        beta_next = beta[1:].copy()
        beta_next[-1] = terminal
        xi = np.einsum('tjr,tk,jk,skr,tks->tjk', alpha[:-1], emit_weights[1:], tmat, ind, beta_next)/prob_data
    else:
        xi = np.zeros((0,K,K))

    # M-step estimates.
    pi_opt = gamma[0]/gamma[0].sum()
    trans_mass = xi.sum(axis = 0)
    row_mass = trans_mass.sum(axis = 1, keepdims = True)
    # rows without posterior mass keep the current transition probabilities
    tmat_opt = np.divide(trans_mass, row_mass, out = np.array(tmat, dtype = float), where = row_mass > 0)

    return [tmat_opt,pi_opt], prob_data

In [97]:
def mv_EM(obs,hmm,cst,sat=True, conv_tol = 1e-10, max_iter = 1000, emit_opt = None):
    '''
    EM loop for the constrained HMM: repeats the array-based Baum-Welch step
    until the transition matrix stops changing.

    obs, hmm, cst, sat: forwarded to arrayConvert / compute_emitweights.
    conv_tol: stop once the norm of the change in the transition matrix is
        no longer greater than this value.
    max_iter: maximum number of Baum-Welch steps to run.
    emit_opt: placeholder for an emission re-estimation hook. Its call
        arguments were never defined, so passing one raises NotImplementedError.

    Returns [tmat, init_prob] after the last step (the converted initial
    parameters if no step is run).
    Expects `mv_BaumWelch(hmm_params, emit_weights, cst_params)` to return
    ([tmat, init_prob], prob_data).
    '''
    if emit_opt:
        raise NotImplementedError(
            "emit_opt is not supported yet: the arguments it should receive are undefined"
        )

    #Convert everything into numpy arrays
    hmm_params, cst_params = arrayConvert(obs, hmm, cst, sat)
    emit_weights = compute_emitweights(obs,hmm)

    for _ in range(max_iter):
        new_hmm_params, dat_prob = mv_BaumWelch(hmm_params, emit_weights, cst_params)
        #stopping criterion based on just the transition matrix
        conv = np.linalg.norm(new_hmm_params[0] - hmm_params[0])
        hmm_params = new_hmm_params
        # `not (conv > tol)` also stops on NaN, matching the original loop test
        if not conv > conv_tol:
            break

    return hmm_params

In [81]:
def mv_BaumWelch(obs, hmm, cst, sat = True, emit_opt = None):
    '''
    All-in-one Baum-Welch step: converts the HMM and the constraint to arrays,
    runs forward-backward on the augmented (state, aux) chain and returns the
    re-estimated initial distribution and transition matrix.

    NOTE: this notebook also contains an array-based
    `mv_BaumWelch(hmm_params, emit_weights, cst_params)` that `mv_EM` calls;
    both share one name, so whichever cell runs last is the one in effect.

    obs: list of observed emissions (length T).
    hmm: object exposing `states`, `initprob[k]`, `tprob[j,k]`, `eprob[k,o]`.
    cst: object exposing `aux_size`, `init_fun`, `update_fun`, `cst_fun`.
    sat: required truth value of the constraint, forwarded to `cst.cst_fun`.
    emit_opt: accepted for interface compatibility; not used here.

    Returns (pi_opt, tmat_opt, prob_data), where prob_data is the joint
    probability of the observations and the requested constraint outcome.
    A row of tmat_opt whose state carries no posterior transition mass keeps
    its current transition probabilities instead of becoming NaN.
    '''
    #Initialize and convert all quantities to np.arrays
    aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
    T = len(obs)
    K = len(hmm.states)
    M = len(aux_space)

    state_ix = {s: i for i, s in enumerate(hmm.states)}
    aux_ix = {s: i for i, s in enumerate(aux_space)}

    tmat = np.zeros((K,K))
    initprob_vec = np.zeros(K)

    for i in hmm.states:
        initprob_vec[state_ix[i]] = hmm.initprob[i]
        for j in hmm.states:
            tmat[state_ix[i],state_ix[j]] = hmm.tprob[i,j]

    # ind[new_aux, state, old_aux], init_ind[aux, state], final_ind[aux]
    ind = np.zeros((M,K,M))
    init_ind = np.zeros((M,K))
    final_ind = np.zeros(M)

    for r in aux_space:
        # uses the caller's `sat` (it must not be overwritten locally)
        final_ind[aux_ix[r]] = cst.cst_fun(r,sat)
        for i in hmm.states:
            init_ind[aux_ix[r],state_ix[i]] = cst.init_fun(i,r)
            for s in aux_space:
                ind[aux_ix[r],state_ix[i],aux_ix[s]] = cst.update_fun(r,i,s)

    #Compute emissions weights for easier access
    emit_weights = np.zeros((T,K))
    for t in range(T):
        emit_weights[t] = np.array([hmm.eprob[k,obs[t]] for k in hmm.states])

    alpha = np.empty((T,K,M))
    beta = np.empty(alpha.shape)

    # Forward pass. alpha[0] is weighted by the FIRST observation's emissions.
    # The constraint indicator belongs to the last step only; applying it
    # after the loop also covers the T == 1 case.
    alpha[0] = np.einsum('i,i,ri->ir', emit_weights[0], initprob_vec, init_ind)
    for t in range(1,T):
        alpha[t] = np.einsum('i,ji,ris,js->ir', emit_weights[t], tmat, ind, alpha[t-1])
    alpha[-1] *= final_ind

    # Backward pass. beta[-1] stays at 1 (the indicator already lives in
    # alpha[-1]); `terminal` is the message that carries final_ind backwards.
    beta[-1] = 1
    terminal = beta[-1]*final_ind
    for t in range(T-2, -1, -1):
        incoming = terminal if t == T-2 else beta[t+1]
        beta[t] = np.einsum('js,j,ij,sjr->ir', incoming, emit_weights[t+1], tmat, ind)

    #Compute P(Y,C=c), probability of observing emissions AND the constraint in the specified truth configuration
    prob_data = np.einsum('ir,ir->',alpha[0],beta[0]) #same value at every time index

    # First moment: posterior state marginals.
    gamma = np.einsum('tir,tir->ti',alpha,beta)/prob_data

    # Second moment: posterior transition marginals. The message for the last
    # transition must include final_ind, otherwise that slice is not
    # conditioned on the constraint.
    if T > 1:
        beta_next = beta[1:].copy()
        beta_next[-1] = terminal
        xi = np.einsum('tjr,tk,jk,skr,tks->tjk', alpha[:-1], emit_weights[1:], tmat, ind, beta_next)/prob_data
    else:
        xi = np.zeros((0,K,K))

    # M-step estimates.
    pi_opt = gamma[0]/gamma[0].sum()
    trans_mass = xi.sum(axis = 0)
    row_mass = trans_mass.sum(axis = 1, keepdims = True)
    # rows without posterior mass keep the current transition probabilities
    tmat_opt = np.divide(trans_mass, row_mass, out = tmat.copy(), where = row_mass > 0)

    return pi_opt,tmat_opt,prob_data

In [None]:
def baum_welch(obs, hmm, cst, sat = True):
    '''
    Impelemnts Baum-Welch to compute the first/second moments in the E-step
    '''
    alpha = np.empty((len(obs), len(hmm.states), ))