In [1]:
import numpy as np
import torch
import json
from munch import Munch
import itertools
from collections import defaultdict
import random
import copy
import pickle
import matplotlib.pyplot as plt

import apt_helper as ahlp
import apt_cst_aggregate as cagg
import mv_Viterbi as mv

In [2]:
import importlib
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [3]:
names = ['apt','bob','sally']
mu_list = [.8,.9,.9]
apt_hmm, bob_hmm, sally_hmm = ahlp.process_load(names, delay = mu_list)
user_list = [bob_hmm, sally_hmm]

In [4]:
cst_names = ['know_sally_exists','have_sally_credential', 'learn_where_data_stored', 'have_data_on_ds', 'have_data_on_hi', 'have_data_on_he']

cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)

cst_list = cst_list
sat = len(cst_list) * (True,)

In [5]:
#Check if correctly loaded. probabilities should sum to 1.
for usr in user_list:
    usr_params = ahlp.hmm2numpy(usr)
    print(f'initprob:{usr_params[0].sum()}  tprob: {usr_params[1].sum(axis = 1)}  eprob: {usr_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [6]:
apt_params, ix_list = ahlp.hmm2numpy(apt_hmm, return_ix = True)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


# Sanity Check that base Viterbi works. Pure Emissions

In [7]:
pure_hidden, pure_emission = ahlp.simulation_apt(apt_hmm, ix_list = None, emit_inhom = False)

In [8]:
opt_state =  mv.mv_Viterbi(obs = pure_emission, hmm = apt_hmm)
for t in range(len(opt_state)):
    if opt_state[t] != pure_hidden[t]:
        print(f'mismatch at time {t} out of {len(opt_state) -1}') #no rule enforcing that that must end on POST

mismatch at time 65 out of 65


### Check that Numpy Implementation Gives Same Answer

In [9]:
apt_params[1].shape

(14, 14)

In [10]:
emit_weights = ahlp.compute_emitweights(pure_emission,apt_hmm, time_hom = True)
hmm_params = [apt_params[1], apt_params[0]]
opt_list = mv.Viterbi_numpy(hmm_params, emit_weights)
state_ix, _ = ix_list
state_ix = {v:k for k,v in state_ix.items()}
numpy_list = [state_ix[i] for i in opt_list]

In [11]:
num_correct = 0
for t in range(len(numpy_list)):
    if numpy_list[t] == opt_state[t]:
        num_correct += 1
print(f'proportion correct: {num_correct/len(numpy_list)}')

proportion correct: 1.0


## Check Numpy with Dummy Constraint Gives Same Answer

In [12]:
importlib.reload(mv)

<module 'mv_Viterbi' from '/home/fyqiu/Projects/conin/conin/mediation_variables/mv_Viterbi.py'>

In [13]:
cst_names = ['dummy_constraints']
cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)

# cst_list = cst_list
sat = len(cst_list) * (True,)

apt_params, cst_params = ahlp.arrayConvert(apt_hmm, cst_list[0], sat = (True,))
opt_cst_list = mv.mv_Viterbi_numpy(apt_params, emit_weights, cst_params)
numpy_cst_list = [state_ix[state[0]] for state in opt_cst_list]

In [14]:
num_correct = 0
for t in range(len(numpy_list)):
    if numpy_list[t] == numpy_cst_list[t]:
        num_correct += 1
print(f'proportion correct: {num_correct/len(numpy_list)}')

proportion correct: 1.0


In [15]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [16]:
opt_augstateix_list = ahlp.Viterbi_torch_list(apt_hmm, cst_list, pure_emission, sat, time_hom = True, device = 'cuda:0')

In [17]:
obs = pure_emission
hmm = apt_hmm
time_hom = True
device = 'cuda:0'
dtype = torch.float32

In [24]:
hmm_params, cst_params = ahlp.arrayConvert(hmm,cst_list[0],sat)a

In [36]:
hmm_params[0].sum()

np.float64(14.0)

In [31]:
hmm_params_torch, cst_params_list_torch = ahlp.convertTensor_list(hmm,cst_list, sat, dtype = torch.float32, device = device)   


In [37]:
for t in range(len(hmm_params)):
    torch_param = hmm_params_torch[t].cpu().numpy()
    numpy_param = hmm_params[t]
    print(np.linalg.norm(torch_param - numpy_param))

4.5585133775756554e-08
0.0


In [18]:
    emit_weights = ahlp.compute_emitweights(obs, hmm, time_hom)

    #Generate hmm,cst params:
    hmm_params, cst_params_list = ahlp.convertTensor_list(hmm,cst_list, sat, dtype = dtype, device = device)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)


In [117]:
V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
V2 = torch.einsum('k,k -> k',emit_weights[0],init_prob).unsqueeze(-1).repeat(1,2)

V = V/V.max() #normalize for numerical stability
val[0] = V.cpu()


In [163]:
t= 1
V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
V = V.reshape(kr_shape + (-1,))
V = V/V.max()
print(V.sum())
max_ix = torch.argmax(V, axis = -1, keepdims = True)
ix_tracker[t-1] = max_ix.squeeze()
V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()


tensor(5.7143, device='cuda:0')


In [165]:
V.sum()

tensor(2., device='cuda:0')

# Numpy

In [None]:
    opt_augstateix_list = []
    
    tmat, init_prob = hmm_params
    init_ind,final_ind,ind = cst_params
    
    T = emit_weights.shape[0]
    K, M = init_ind.shape

    val = np.empty((T,K,M))
    ix_tracker = np.empty((T,K,M)) #will store flattened indices

    #Forward pass
    V = np.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    val[0] = V
    for t in range(1,T):
        V = np.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = V.reshape((K,M,-1))
        max_ix = np.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = np.take_along_axis(V, max_ix, axis=-1).squeeze()
        if t == T:
            val[t] = np.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
        else:
            val[t] = np.einsum('k,kr -> kr', emit_weights[t],V)


In [None]:
def mv_Viterbi_numpy(hmm_params, emit_weights, cst_params = None):
    '''
    numpy version. hmm_params, cst_params are list of numpy arrays
    '''
    if cst_params is None:
        return Viterbi_numpy(hmm_params, emit_weights)
    
    opt_augstateix_list = []
    
    tmat, init_prob = hmm_params
    init_ind,final_ind,ind = cst_params
    
    T = emit_weights.shape[0]
    K, M = init_ind.shape

    val = np.empty((T,K,M))
    ix_tracker = np.empty((T,K,M)) #will store flattened indices

    #Forward pass
    V = np.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    val[0] = V
    for t in range(1,T):
        V = np.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = V.reshape((K,M,-1))
        max_ix = np.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = np.take_along_axis(V, max_ix, axis=-1).squeeze()
        if t == T:
            val[t] = np.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
        else:
            val[t] = np.einsum('k,kr -> kr', emit_weights[t],V)
        

    #Backward pass

    #Initialize the last index
    max_ix = int(np.argmax(val[T-1]).item())
    max_k, max_r =divmod(max_ix, M)
    opt_augstateix_list = [(max_k,max_r)] + opt_augstateix_list

    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        max_k, max_r = divmod(max_ix, M)
        opt_augstateix_list = [(max_k,max_r)] + opt_augstateix_list

    return opt_augstateix_list


In [None]:
def Viterbi_torch_list(hmm, cst_list, obs, sat,  time_hom = True, dtype = torch.float16,  device = 'cpu'):
    '''
    
    '''
    #Generate emit_weights:
    emit_weights = compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list = convertTensor_list(hmm,cst_list, sat, dtype = dtype, device = device)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    
    #Viterbi
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)
    #Forward pass
    # V = torch.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    val[0] = V.cpu()
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
        V = V.reshape((K,) + tuple(dims_list) + (-1,))
        V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
        if t == T:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()
        

    #Backward pass
    opt_augstateix_list = []
    max_ix = int(torch.argmax(val[T-1]).item())
    unravel_max_ix = np.unravel_index(max_ix, kr_shape)
    opt_augstateix_list =  [np.array(unravel_max_ix).tolist()] + opt_augstateix_list
    
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        unravel_max_ix = np.unravel_index(max_ix, kr_shape)
        opt_augstateix_list =  [np.array(unravel_max_ix).tolist()] + opt_augstateix_list

    return opt_augstateix_list


# Sanity Check. Noiseless Tiered APT Equivalent to Original

In [28]:
tier_apt = ahlp.create_tiered_apt(apt_hmm)
apt_params = ahlp.hmm2numpy(tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.]


In [29]:
opt_state_tier =  mv.mv_Viterbi(obs = pure_emission, hmm = tier_apt)
num_correct = 0
for t in range(len(opt_state)):
    # if opt_state[t] != opt_state_tier[t][0]:
    #     print(f'mismatch at time {t} out of {len(opt_state) -1}') #no rule enforcing that that must end on POST
    if opt_state[t] == opt_state_tier[t][0]:
        num_correct += 1 
print(num_correct/len(opt_state))

1.0


# Noisy Simulations

In [79]:
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)
apt_truth_states, apt_truth_emits = apt_truth

In [80]:
opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)
num_correct = 0
for t in range(len(opt_state)):
    if opt_state[t] == apt_truth_states[t]:
        num_correct += 1 #no rule enforcing that that must end on POST
print(num_correct/len(opt_state))

0.5


In [25]:
B = 1000
accuracy_list = []
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)
    num_correct = 0
    for t in range(len(opt_state)):
        if opt_state[t] == apt_truth_states[t]:
            num_correct += 1 #no rule enforcing that that must end on POST
    if b % 100 == 0:
        print(b)
    accuracy_list.append(num_correct/len(opt_state))
print(f'average proportion correct is {sum(accuracy_list)/len(accuracy_list)}')

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
average proportion correct is 0.4515789464311288


# Noisy Simulations with Constraints

In [91]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [92]:
B = 1000
accuracy_list = []
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)
    num_correct = 0
    for t in range(len(opt_state)):
        if opt_state[t] == apt_truth_states[t]:
            num_correct += 1 #no rule enforcing that that must end on POST
    if b % 100 == 0:
        print(b)
    accuracy_list.append(num_correct/len(opt_state))
print(f'average proportion correct is {sum(accuracy_list)/len(accuracy_list)}')

0
100
200
300
400
500
600
700
800
900
average proportion correct is 0.3373810406888345


In [94]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

# Check if Noisy APT Is Better (Simulation with Constraints)

In [100]:
noisy_apt = ahlp.create_noisy_apt(apt_hmm, 1/3)
apt_params = ahlp.hmm2numpy(noisy_tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1.]


In [108]:
B = 1000
accuracy_list = []
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = noisy_apt)
    num_correct = 0
    for t in range(len(opt_state)):
        if opt_state[t] == apt_truth_states[t]:
            num_correct += 1 #no rule enforcing that that must end on POST
    if b % 100 == 0:
        print(b)
    accuracy_list.append(num_correct/len(opt_state))
print(f'average proportion correct is {sum(accuracy_list)/len(accuracy_list)}')

0
100
200
300
400
500
600
700
800
900
average proportion correct is 0.07559917904947176


In [203]:
agg_cst, zip_list, cst_ix = cagg.apt_cst_aggregate(cst_list, debug = True)

In [214]:
tier_apt_mix, ix_list = ahlp.lapt_mixture(tier_apt, user_list, len(combined_emits), mix_weights = None, return_ix = True)

In [215]:
noisy_tier_apt = ahlp.create_noisy_apt(tier_apt, 1/3)

In [216]:
apt_params = ahlp.hmm2numpy(noisy_tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]


In [218]:
#Vanilla Viterbi is no constraint is included
opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = noisy_tier_apt)

In [129]:
def create_noisy_apt(apt_hmm, mix_param, tol = 1e-7):
    '''
    Original APT: X -> Y

    Tiered APT: (X,Y) -> hat{Y}

    
    '''
    apt = copy.deepcopy(apt_hmm) #deepcopy since there's still some funkiness going on.
    M = len(apt.emits) #number of hidden states
    
    #Creat noisy emissions matrix
    eprob = defaultdict(int)
    #For now, create a noiseless emission, where the emission of the APT is the observed emission
    for k in apt.states:
        for e in apt.emits:
            eprob[k,e] = mix_param*apt.eprob[k,e] + (1- mix_param)/M 
        
    new_apt = Munch(name = apt.name, states = apt.states, emits = apt.emits, tprob = apt.tprob, \
                       eprob = eprob, initprob = apt.initprob)
    
    if apt.mu:
        new_apt.mu = apt.mu

    return new_apt


In [95]:
combined_emits[:10]

[None,
 ('V', 'access/bob'),
 None,
 None,
 None,
 ('S', 'postfix/local'),
 ('HE', 'img/post'),
 None,
 ('DS', 'syslog/ls'),
 None]

In [96]:
apt_hmm.initprob['CA']

0

In [98]:
hmm = apt_hmm
obs = combined_emits[:10]

In [None]:
    val = {} #initialize value dictionary

    for k in hmm.states:
            val[0,k] = hmm.initprob[k]*hmm.eprob[k,obs[0]]
            
    ix_tracker = {} #this is used in the backwards step to identify the optimal sequence
    
    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            max_val = -1 # set to dummy variable. will do brute-force search for max
            argmax = None #initialize argmax for ix_tracker
            for j in hmm.states:
                curr_val = val[t-1,j]*hmm.tprob[j,k]
                if curr_val > max_val:
                    max_val = curr_val
                    argmax = j
                    print(j)
            val[t,k] = max_val*hmm.eprob[k,obs[t]]
            ix_tracker[t-1,k] = argmax


In [100]:
val[0,'Pre']

KeyError: (0, 'Pre')

In [None]:
    val = {} #initialize value dictionary

    for k in hmm.states:
            val[0,k] = hmm.initprob[k]*hmm.eprob[k,obs[0]]
            
    ix_tracker = {} #this is used in the backwards step to identify the optimal sequence
    
    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            max_val = -1 # set to dummy variable. will do brute-force search for max
            argmax = None #initialize argmax for ix_tracker
            for j in hmm.states:
                curr_val = val[t-1,j]*hmm.tprob[j,k]
                if curr_val > max_val:
                    max_val = curr_val
                    argmax = j
            val[t,k] = max_val*hmm.eprob[k,obs[t]]
            ix_tracker[t-1,k] = argmax
    
    #Backward: compute the values of the optimal sequence
    max_val = -1
    best_state = None
    for k in hmm.states:
        curr_val = val[len(obs) - 1,k]
        if curr_val > max_val:
            max_val = curr_val
            best_state = k
    opt_state = [best_state]
    
    for t in range(len(obs) -1):
        best_state = ix_tracker[len(obs) - 2 -t,best_state]
        opt_state = [best_state] + opt_state

    return opt_state


In [None]:
apt_hmm_mix.eprob

In [81]:
apt_truth[0]

['PRE',
 'PRE',
 'PRE',
 'PRE',
 'PRE',
 'IA',
 'EX',
 'WAIT_EX',
 'DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'EXF',
 'WAIT_EXF',
 'WAIT_EXF',
 'WA

In [80]:
opt_state

[('PRE', None),
 ('IA', ('S', 'postfix/local')),
 ('EX', ('S', 'postfix/local')),
 ('DI', ('DS', 'syslog/ls')),
 ('COL', ('HE', 'img/post')),
 ('EXF', ('HE', 'img/query')),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', N

In [None]:
len(tier_apt.hmm)

In [122]:
sat = (True,) * 3
opt_aug, opt_state =  mv.mv_Viterbi_v2(obs = combined_emits, hmm = tier_apt, cst= agg_cst,sat = sat)

1
2
3


KeyboardInterrupt: 

In [None]:
mv_Viterbi(obs, hmm, combined_cst, sat = (True,True))