In [9]:
import numpy as np
import warnings
from munch import Munch
import itertools

In [87]:
# Toy HMM: three hidden states, two emission symbols.
hidden_states = ['a','b','c']
emit_states = ['A','C']

# Every ordered pair of hidden states gets the same transition probability.
hmm_transition = {
    (src, dst): 1/3
    for src in hidden_states
    for dst in hidden_states
}

# Row i is the emission distribution of hidden_states[i] over emit_states.
emit_mat = np.array([
    [.8,.2],
    [.5,.5],
    [.2,.8]
])
# Flatten the matrix into a {(hidden state, symbol): probability} lookup
# holding plain Python floats.
hmm_emit = {
    (state, symbol): emit_mat[row, col].item()
    for row, state in enumerate(hidden_states)
    for col, symbol in enumerate(emit_states)
}

# Uniform distribution over the initial hidden state.
hmm_startprob = dict.fromkeys(hidden_states, 1/3)

# Bundle everything under the attribute names read by mv_Viterbi.
hmm = Munch(
    states = hidden_states,
    emits = emit_states,
    tprob = hmm_transition,
    eprob = hmm_emit,
    initprob = hmm_startprob,
)

In [88]:
def mv_Viterbi(obs, hmm, cst, sat = True):
    '''
    Viterbi decoding over an augmented state (hidden state, auxiliary state).

    The auxiliary state is a tuple of cst.aux_size booleans that is updated
    deterministically alongside the hidden state. The constraint is included
    as a binary emission of the auxiliary state at the last time step, which
    lets us decode the most likely path given that the constraint is
    satisfied or given that it is not.

    obs: non-empty list of observed emissions
    hmm: object holding the HMM. Reads hmm.states, hmm.initprob[k],
         hmm.tprob[j,k] and hmm.eprob[k,emission]
    cst: object holding the constraint. Reads cst.aux_size,
         cst.init_fun(k,r), cst.update_fun(r,k,r_past) and cst.cst_fun(r,sat)
    sat: passed unchanged to cst.cst_fun; states whether the constraint
         should hold or not

    Returns (opt_augstate, opt_state): the decoded list of
    (hidden state, auxiliary state) pairs and the list of hidden states only.

    Emits a RuntimeWarning when the best path has a score of zero, i.e. when
    no path with positive probability is compatible with the constraint.
    '''
    aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
    last_t = len(obs) - 1
    val = {}

    # Initialisation: the auxiliary state at time 0 is fixed by the first hidden state.
    for k in hmm.states:
        for r in aux_space:
            val[0,k,r] = cst.init_fun(k,r)*hmm.initprob[k]*hmm.eprob[k,obs[0]]

    ix_tracker = {}

    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            for r in aux_space:
                max_val = -1 # below any product of probabilities, so the first candidate always wins
                argmax = None #initialize argmax for ix_tracker
                for j in hmm.states:
                    for s in aux_space:
                        # The auxiliary state r at time t is determined by the hidden
                        # state k at time t and the previous auxiliary state s. Passing
                        # the previous hidden state j instead would leave the final
                        # hidden state out of the constraint.
                        curr_val = val[t-1,j,s]*hmm.tprob[j,k]*cst.update_fun(r,k,s)
                        if curr_val > max_val:
                            max_val = curr_val
                            argmax = (j,s)
                val[t,k,r] = max_val*hmm.eprob[k,obs[t]]
                ix_tracker[t-1,k,r] = argmax

    # Add the constraint as an emission at the last time. This is done outside
    # the forward loop so that it is also applied when there is a single observation.
    for k in hmm.states:
        for r in aux_space:
            val[last_t,k,r] = val[last_t,k,r]*cst.cst_fun(r,sat)

    #Backward: compute the values of the optimal sequence
    max_val = -1
    best_state = None
    for k in hmm.states:
        for r in aux_space:
            curr_val = val[last_t,k,r]
            if curr_val > max_val:
                max_val = curr_val
                best_state = (k,r)

    if max_val <= 0:
        warnings.warn(
            "mv_Viterbi: best path has zero score; the returned path is arbitrary",
            RuntimeWarning,
            stacklevel=2,
        )

    # Follow the back-pointers from the last time step, then reverse once.
    reversed_path = [best_state]
    for t in range(last_t - 1, -1, -1):
        best_state = ix_tracker[t,best_state[0],best_state[1]]
        reversed_path.append(best_state)
    opt_augstate = reversed_path[::-1]
    opt_state = [pair[0] for pair in opt_augstate]

    return(opt_augstate, opt_state)

In [94]:
def update_fun(r, k, r_past):
    '''
    Indicator that r is the auxiliary state obtained from r_past and state k.

    The auxiliary state is a 2-tuple (m1, m2):
      m1 is set once state 'a' has been seen (k == 'a') and stays set afterwards.
      m2 keeps its previous value, except that it is cleared when k == 'c'
         while m1 is still unset (a 'c' arrived before any 'a').

    r: candidate auxiliary state, a 2-tuple
    k: hidden state driving the update
    r_past: previous auxiliary state, a 2-tuple

    Returns 1 if r equals the updated (m1, m2), else 0.
    '''
    seen_a = True if k == 'a' else r_past[0]

    if seen_a or k != 'c':
        order_ok = r_past[1]
    else:
        order_ok = False

    expected = (seen_a, order_ok)
    return int(r == expected)

def init_fun(k, r):
    '''
    Indicator that r = (m1, m2) is the auxiliary state implied by the first state k.

    m1 is True only when k is 'a'; m2 is True unless k is 'c'.
    Returns 1 when r matches, else 0.
    '''
    started_with_a = (k == 'a')
    not_started_with_c = (k != 'c')
    expected = (started_with_a, not_started_with_c)

    return int(r == expected)
    
def cst_fun(r, sat):
    '''
    Constraint as a boolean emission of the final auxiliary state r = (m1, m2).

    Only the second component (m2) is tested.
    Returns 1 when m2 equals sat, else 0.
    '''
    m2 = r[1]
    return int(m2 == sat)

In [95]:
# Precedence constraint: two auxiliary booleans plus the three functions above.
prec_cst = Munch(
    name = 'a occurs before c',
    aux_size = 2,
    update_fun = update_fun,
    init_fun = init_fun,
    cst_fun = cst_fun,
)

In [115]:
def update_fun2(r, k, r_past):
    '''
    Indicator that r is the auxiliary state obtained from r_past and state k.

    The auxiliary state is a 1-tuple (m1,) where m1 tracks whether 'b' has
    occurred: it is set when k == 'b' and otherwise keeps its previous value.

    Returns 1 if r equals the updated (m1,), else 0.
    '''
    seen_b = True if k == 'b' else r_past[0]
    expected = (seen_b,)

    return int(r == expected)

def init_fun2(k, r):
    '''
    Indicator that r = (m1,) is the auxiliary state implied by the first state k.

    m1 is True only when k is 'b'. Returns 1 when r matches, else 0.
    '''
    expected = (k == 'b',)

    return int(r == expected)
    
def cst_fun2(r, sat):
    '''
    Constraint as a boolean emission of the final auxiliary state r = (m1,).

    Returns 1 when m1 equals sat, else 0.
    '''
    m1 = r[0]
    return int(m1 == sat)

In [116]:
# Occurrence constraint: one auxiliary boolean plus the three "*2" functions above.
occur_cst = Munch(
    name = 'b must occur',
    aux_size = 1,
    update_fun = update_fun2,
    init_fun = init_fun2,
    cst_fun = cst_fun2,
)

In [117]:
from mv_Viterbi import mv_Viterbi

In [118]:
# Observed emission sequence: one single-character symbol per time step.
obs = list('AACAA')

In [119]:
# Decode the most likely path given that the occurrence constraint holds.
opt_aug, opt_state = mv_Viterbi(
    obs,
    hmm,
    occur_cst,
    sat = True,
)

In [120]:
opt_state

['b', 'a', 'c', 'a', 'a']

In [121]:
def create_updatefun(zip_list):
    '''
    Build the update function of an aggregated constraint.

    zip_list: list of (cst, ix) pairs, where ix = (start, stop) is the slice
              of the combined auxiliary state that belongs to cst.

    The returned function multiplies together each constraint's update_fun,
    evaluated on that constraint's slice of r and r_past (passed as tuples).
    '''
    def update_fun_agg(r, k, r_past):
        product = 1
        for constraint, ix in zip_list:
            start, stop = ix[0], ix[1]
            own_r = tuple(r[start:stop])
            own_past = tuple(r_past[start:stop])
            product = product * constraint.update_fun(own_r, k, own_past)
        return product
    return update_fun_agg

def create_initfun(zip_list):
    '''
    Build the init function of an aggregated constraint.

    zip_list: list of (cst, ix) pairs, where ix = (start, stop) is the slice
              of the combined auxiliary state that belongs to cst.

    The returned function multiplies together each constraint's init_fun,
    evaluated on that constraint's slice of r (passed as a tuple).
    '''
    def init_fun_agg(k, r):
        product = 1
        for constraint, ix in zip_list:
            start, stop = ix[0], ix[1]
            own_r = tuple(r[start:stop])
            product = product * constraint.init_fun(k, own_r)
        return product
    return init_fun_agg

def create_cstfun(zip_list):
    '''
    Build the constraint-emission function of an aggregated constraint.

    zip_list: list of (cst, ix) pairs, where ix = (start, stop) is the slice
              of the combined auxiliary state that belongs to cst.

    The returned function takes the combined auxiliary state r and a sequence
    sat holding one entry per constraint, in the same order as zip_list. It
    multiplies together each constraint's cst_fun, evaluated on that
    constraint's slice of r (passed as a tuple) and its own entry of sat.
    '''
    def cst_fun_agg(r, sat):
        product = 1
        for position, (constraint, ix) in enumerate(zip_list):
            start, stop = ix[0], ix[1]
            own_r = tuple(r[start:stop])
            product = product * constraint.cst_fun(own_r, sat[position])
        return product
    return cst_fun_agg

In [165]:
def cst_aggregate(cst_list):
    '''
    Combine several constraints into a single constraint object.

    The auxiliary states of the individual constraints are laid out side by
    side, in the order of cst_list; each constraint owns the slice
    (start, stop) of the combined auxiliary tuple.

    cst_list: list of constraint objects exposing name, aux_size,
              update_fun, init_fun and cst_fun.

    Returns a Munch with
      name: list of the individual names,
      aux_size: total number of auxiliary variables,
      update_fun / init_fun / cst_fun: the aggregated functions built by
          create_updatefun, create_initfun and create_cstfun.
    '''
    offset = 0
    spans = []
    names = []
    for constraint in cst_list:
        stop = offset + constraint.aux_size
        spans.append((offset, stop)) # slice of the combined aux state owned by this constraint
        names.append(constraint.name)
        offset = stop
    zip_list = list(zip(cst_list, spans))

    # After the loop, offset is the total width of the combined auxiliary state.
    return Munch(
        name = names,
        aux_size = offset,
        update_fun = create_updatefun(zip_list),
        init_fun = create_initfun(zip_list),
        cst_fun = create_cstfun(zip_list),
    )

In [153]:
# cst_aggregate returns a single combined constraint object, so bind it to one
# name; unpacking it into two names would raise a ValueError.
cst_combined = cst_aggregate([prec_cst, occur_cst])

In [154]:
# Observed emission sequence: one single-character symbol per time step.
obs = list('AACAA')

In [163]:
# Decode with the combined constraint. sat carries one flag per aggregated
# constraint, in the order they were passed to cst_aggregate.
opt_aug, opt_state = mv_Viterbi(
    obs,
    hmm,
    cst_combined,
    sat = [False,False],
)

In [164]:
opt_state

['c', 'a', 'c', 'a', 'a']

['a occurs before c', 'b must occur']