In [78]:
import numpy as np
import torch
import json
from munch import Munch
import itertools
from collections import defaultdict
import random
import copy
import pickle
import matplotlib.pyplot as plt
import importlib

import apt_helper as ahlp
import apt_cst_aggregate as cagg
import mv_Viterbi as mv

In [229]:
import importlib
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [230]:
names = ['apt','bob','sally']
mu_list = [.8,.9,.9]
apt_hmm, bob_hmm, sally_hmm = ahlp.process_load(names, delay = mu_list)
user_list = [bob_hmm, sally_hmm]

In [231]:
def load_constraints(module_names):
    '''
    Import each constraint module by name and wrap its hooks in a Munch.

    Every module must expose `name`, `aug_size`, `update_fun`, `init_fun`,
    `forbidden_emissions`, `forbidden_transitions`, `knowledge_state` and
    `cst_fun`. An optional module-level `dependency` attribute is copied
    over when it exists.

    Parameters
    ----------
    module_names : iterable of str
        Importable module names, one per constraint.

    Returns
    -------
    list of Munch
        One record per module, in the order given.
    '''
    loaded = []
    for module_name in module_names:
        module = importlib.import_module(module_name)
        # note: the module attribute is `aug_size`, the record field is `aux_size`
        cst = Munch(name = module.name,
                    aux_size = module.aug_size,
                    update_fun = module.update_fun,
                    init_fun = module.init_fun,
                    forbidden_emissions = module.forbidden_emissions,
                    forbidden_transitions = module.forbidden_transitions,
                    knowledge_state = module.knowledge_state,
                    cst_fun = module.cst_fun)
        if hasattr(module, 'dependency'):
            cst.dependency = module.dependency
        loaded.append(cst)
    return loaded


cst_names = ['know_sally_exists','have_sally_credential', 'learn_where_data_stored', 'have_data_on_ds', 'have_data_on_hi', 'have_data_on_he']
cst_list = load_constraints(cst_names)

# One satisfaction flag per constraint; here every constraint is required to hold.
sat = len(cst_list) * (True,)

In [232]:
#Check if correctly loaded. probabilities should sum to 1.
for usr in user_list:
    usr_params = ahlp.hmm2numpy(usr)
    print(f'initprob:{usr_params[0].sum()}  tprob: {usr_params[1].sum(axis = 1)}  eprob: {usr_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [233]:
apt_params, ix_list = ahlp.hmm2numpy(apt_hmm, return_ix = True)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


# Sanity Check that base Viterbi works. Pure Emissions

In [234]:
len(apt_hmm.eprob.keys())

25

In [235]:
pure_hidden, pure_emission = ahlp.simulation_apt(apt_hmm, ix_list = None, emit_inhom = False)

In [236]:
len(apt_hmm.eprob.keys())

25

In [237]:
# Decode the noise-free emission sequence with the dictionary-based Viterbi and
# report every time step where the decoded state differs from the simulated one.
opt_state =  mv.mv_Viterbi(obs = pure_emission, hmm = apt_hmm)
for t in range(len(opt_state)):
    if opt_state[t] != pure_hidden[t]:
        print(f'mismatch at time {t} out of {len(opt_state) -1}') # no rule forces the decoded path to end on POST

mismatch at time 114 out of 138
mismatch at time 115 out of 138
mismatch at time 116 out of 138
mismatch at time 117 out of 138
mismatch at time 118 out of 138
mismatch at time 119 out of 138
mismatch at time 120 out of 138
mismatch at time 121 out of 138
mismatch at time 122 out of 138
mismatch at time 123 out of 138
mismatch at time 124 out of 138
mismatch at time 125 out of 138
mismatch at time 126 out of 138
mismatch at time 127 out of 138
mismatch at time 128 out of 138
mismatch at time 129 out of 138
mismatch at time 130 out of 138
mismatch at time 131 out of 138
mismatch at time 132 out of 138
mismatch at time 133 out of 138
mismatch at time 134 out of 138
mismatch at time 135 out of 138
mismatch at time 136 out of 138
mismatch at time 137 out of 138


In [238]:
len(apt_hmm.eprob.keys())

25

### Check that Numpy Implementation Gives Same Answer

In [239]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [240]:
emit_weights = ahlp.compute_emitweights(pure_emission,apt_hmm, time_hom = True)
hmm_params = [apt_params[1], apt_params[0]]
opt_list = mv.Viterbi_numpy(hmm_params, emit_weights)
state_ix, _ = ix_list
state_ix = {v:k for k,v in state_ix.items()}
numpy_list = [state_ix[i] for i in opt_list]

In [241]:
num_correct = 0
for t in range(len(numpy_list)):
    if numpy_list[t] == opt_state[t]:
        num_correct += 1
print(f'proportion correct: {num_correct/len(numpy_list)}')

proportion correct: 1.0


In [242]:
len(apt_hmm.eprob.keys())

25

In [243]:
numpy_list[:10]

['PRE',
 'PRE',
 'IA',
 'WAIT_IA',
 'EX',
 'WAIT_EX',
 'WAIT_EX',
 'WAIT_EX',
 'WAIT_EX',
 'EX']

## Check Numpy with Dummy Constraint Gives Same Answer

In [244]:
cst_names = ['dummy_constraint']
dummy_cst_list =  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    dummy_cst_list.append(curr_cst)

# cst_list = cst_list
dummy_sat = len(dummy_cst_list) * (True,)

apt_params, cst_params = ahlp.arrayConvert(apt_hmm, dummy_cst_list[0], sat = (True,))
opt_cst_list = mv.mv_Viterbi_numpy(apt_params, emit_weights, cst_params)
numpy_cst_list = [state_ix[state[0]] for state in opt_cst_list]

In [245]:
num_correct = 0
for t in range(len(numpy_list)):
    if numpy_list[t] == numpy_cst_list[t]:
        num_correct += 1
print(f'proportion correct: {num_correct/len(numpy_list)}')

proportion correct: 1.0


In [246]:
len(apt_hmm.eprob.keys())

25

# Sanity Check. Torch Version with Dummy Constraint Also Gives Same Answer

In [247]:
# Decode with the torch implementation under the dummy constraint and compare
# state-by-state against the numpy result from the previous section.
opt_torch_list, _ = ahlp.Viterbi_torch_list(apt_hmm, dummy_cst_list, pure_emission, dummy_sat, time_hom = True, device = 'cuda:0')
# Loop variables are named so they do not shadow the `torch` module inside the comprehension.
num_correct = [torch_state == numpy_state for torch_state, numpy_state in zip(opt_torch_list, numpy_cst_list)]
# sum() counts the True entries; dividing len() by len() would always print 1.0.
print(f'proportion correct: {sum(num_correct)/len(num_correct)}')

proportion correct: 1.0


# Sanity Check. Noiseless Tiered APT Equivalent to Original

In [248]:
tier_apt = ahlp.create_tiered_apt(apt_hmm)
apt_params = ahlp.hmm2numpy(tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]


In [249]:
opt_state_tier =  mv.mv_Viterbi(obs = pure_emission, hmm = tier_apt)
num_correct = 0
for t in range(len(opt_state)):
    # if opt_state[t] != opt_state_tier[t][0]:
    #     print(f'mismatch at time {t} out of {len(opt_state) -1}') #no rule enforcing that that must end on POST
    if opt_state[t] == opt_state_tier[t][0]:
        num_correct += 1 
print(num_correct/len(opt_state))

1.0


In [250]:
len(apt_hmm.eprob.keys())

25

# Noisy Simulations

In [251]:
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)
apt_truth_states, apt_truth_emits = apt_truth

In [252]:
opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)
num_correct = 0
for t in range(len(opt_state)):
    if opt_state[t] == apt_truth_states[t]:
        num_correct += 1 #no rule enforcing that that must end on POST
print(num_correct/len(opt_state))

0.7741935483870968


In [33]:
B = 1000
accuracy_list = []
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)
    num_correct = [opt == truth for opt,truth in zip(opt_state,apt_truth_states)]
    if b % 100 == 0:
        print(b)
    accuracy_list.append(sum(num_correct)/len(num_correct))
print(f'average proportion correct is {sum(accuracy_list)/len(accuracy_list)}')

0
100


KeyboardInterrupt: 

In [253]:
len(apt_hmm.eprob.keys())

25

# Noisy Constrained Simulations 

In [32]:
B = 1000
accuracy_list = []
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)
    if not ahlp.check_valid(*apt_truth, cst_list):
        print('Invalid sequence')
    num_correct = [opt == truth for opt,truth in zip(opt_state,apt_truth_states)]
    if b % 100 == 0:
        print(b)
    accuracy_list.append(sum(num_correct)/len(num_correct))
print(f'average proportion correct is {sum(accuracy_list)/len(accuracy_list)}')

0
100


KeyboardInterrupt: 

# Check if Noisy APT Is Better (Simulation with Constraints)

In [254]:
noisy_apt = ahlp.create_noisy_apt(apt_hmm, 1/3)
apt_params = ahlp.hmm2numpy(noisy_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [255]:
len(apt_hmm.eprob.keys())

25

In [221]:
B = 1000
accuracy_list = []
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = noisy_apt)
    num_correct = [opt == truth for opt,truth in zip(opt_state,apt_truth_states)]
    if b % 100 == 0:
        print(b)
    accuracy_list.append(sum(num_correct)/len(num_correct))
print(f'average proportion correct is {sum(accuracy_list)/len(accuracy_list)}')

0


KeyboardInterrupt: 

In [256]:
cst_names = ['know_sally_exists','have_sally_credential', 'learn_where_data_stored', 'have_data_on_ds', 'have_data_on_hi', 'have_data_on_he']
cst_names = [names + '_TRUE' for names in cst_names]
cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)

sat = len(cst_list) * (True,)

In [268]:
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)
apt_truth_states, apt_truth_emits = apt_truth


In [273]:
len(apt_hmm.eprob.keys())

25

In [274]:
importlib.reload(mv)

<module 'mv_Viterbi' from '/home/fyqiu/Projects/conin/conin/mediation_variables/mv_Viterbi.py'>

In [278]:
agg_cst = cagg.apt_cst_aggregate(cst_list[:2])
opt_state =  mv.mv_Viterbi_v2(obs = combined_emits, hmm = apt_hmm, cst = agg_cst, sat = sat)
opt_state[:10]

['PRE',
 'IA',
 'WAIT_IA',
 'WAIT_IA',
 'WAIT_IA',
 'WAIT_IA',
 'WAIT_IA',
 'WAIT_IA',
 'WAIT_IA',
 'WAIT_IA']

In [279]:
opt_state_unc = mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)

In [280]:
num_correct = [cst == ucst for cst,ucst in zip(opt_state,opt_state_unc)]
print(sum(num_correct)/len(num_correct))

1.0


In [22]:
agg_cst = cagg.apt_cst_aggregate(cst_list)

In [23]:
tier_apt_mix, ix_list = ahlp.lapt_mixture(tier_apt, user_list, len(combined_emits), mix_weights = None, return_ix = True)

In [24]:
noisy_tier_apt = ahlp.create_noisy_apt(tier_apt, 1/3)

In [25]:
apt_params = ahlp.hmm2numpy(noisy_tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [218]:
# Vanilla Viterbi: no constraint argument is passed, so the noisy tiered APT is
# decoded without any constraint.
opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = noisy_tier_apt)

# Work

In [102]:
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)


In [298]:
cst_names = ['know_sally_exists','have_sally_credential', 'learn_where_data_stored', 'have_data_on_ds', 'have_data_on_hi', 'have_data_on_he']
cst_names = [names + '_TRUE' for names in cst_names]
cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)

# cst_list = cst_list[:1]
sat = len(cst_list) * (True,)
agg_cst = cagg.apt_cst_aggregate(cst_list)

In [299]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [282]:
hmm_params, cst_params = ahlp.convertTensor_list(tier_apt, cst_list, sat, dtype = torch.float16, device = 'cpu', return_ix = False)

In [283]:
dims_list, init_ind_list,final_ind_list,ind_list = cst_params

In [284]:
hmm_params_np, cst_param1 = ahlp.arrayConvert(tier_apt, cst_list[0], sat[0])

In [285]:
for i in range(2):
    err = hmm_params_np[i] - hmm_params[i].cpu().numpy()
    print(np.linalg.norm(err))

0.0009636887170864863
0.0


In [286]:
for i in range(len(cst_list)):
    _, cst_param1 = ahlp.arrayConvert(tier_apt, cst_list[i], sat[i])
    init_ind,final_ind,ind = cst_param1
    init_err = np.linalg.norm(init_ind - init_ind_list[2*i].cpu().numpy()).item()
    final_err = np.linalg.norm(final_ind - final_ind_list[2*i].cpu().numpy()).item()
    ind_err = np.linalg.norm(ind - ind_list[2*i].cpu().numpy()).item()

    print(f'init_err: {init_err} final_err {final_err} ind_err {ind_err}')

init_err: 0.0 final_err 0.0 ind_err 0.0
init_err: 0.0 final_err 0.0 ind_err 0.0
init_err: 0.0 final_err 0.0 ind_err 0.0
init_err: 0.0 final_err 0.0 ind_err 0.0
init_err: 0.0 final_err 0.0 ind_err 0.0
init_err: 0.0 final_err 0.0 ind_err 0.0


In [368]:
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list, cst_list)
apt_truth_states, apt_truth_emits = apt_truth

In [372]:
opt_state_list, opt_augstateix_list = ahlp.Viterbi_torch_list(tier_apt, dummy_cst_list, combined_emits, dummy_sat, time_hom = True, dtype = torch.float16, device = 'cuda:0')

In [373]:
opt_state_list[:10]

[('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None)]

In [369]:
opt_state_list, opt_augstateix_list = ahlp.Viterbi_torch_list(tier_apt, cst_list, apt_truth_emits, sat, time_hom = True, dtype = torch.float16, device = 'cuda:0')

In [371]:
opt_state_list[:10]

[('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None)]

In [354]:
opt_state_list[:10]

[('PRE', None),
 ('PRE', None),
 ('IA', ('S', 'postfix/local')),
 ('WAIT_IA', None),
 ('EX', ('HE', 'img/post')),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('EX', ('S', 'postfix/local'))]

In [351]:
opt_state_list[:10]

[('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None),
 ('PRE', None)]

In [347]:
opt_state_list[:15]

[('PRE', None),
 ('PRE', None),
 ('IA', ('S', 'postfix/local')),
 ('WAIT_IA', None),
 ('EX', ('HE', 'img/post')),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('EX', ('S', 'postfix/local')),
 ('EX', ('HE', 'img/post')),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None)]

In [359]:
noisy_apt = ahlp.create_noisy_apt(tier_apt, 1/3)

In [305]:
opt_state_list, opt_augstateix_list = ahlp.Viterbi_torch_list(apt_hmm, dummy_cst_list, combined_emits, dummy_sat, time_hom = True, dtype = torch.float16, device = 'cuda:0')

In [310]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [313]:
hmm_params, cst_params = ahlp.convertTensor_list(noisy_apt, dummy_cst_list, dummy_sat, dtype = torch.float16, device = 'cpu', return_ix = False)

In [324]:
cst_params[3][1]#.shape

[0, 1, 2, 3]

In [343]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [339]:
obs = pure_emission
hmm = tier_apt
time_hom = True
device = 'cuda:0'
dtype = torch.float32

In [340]:
    emit_weights = ahlp.compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list, state_ix = ahlp.convertTensor_list(hmm,cst_list, sat, dtype = dtype, \
                                                               device = device, return_ix = True)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    
    #Viterbi
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)
    js_indices = [k + C + 1 for k in kr_indices]


In [341]:
    # Inlined copy of the torch Viterbi, kept for step-by-step debugging.
    # Relies on emit_weights, init_prob, tmat, the *_ind_list tensors, val,
    # ix_tracker, kr_indices, js_indices, kr_shape, T, K, C, device and state_ix
    # defined by the preceding set-up cell.

    # Initialisation: V[k, r...] = emit[0, k] * init_prob[k] * prod_c init_ind_c[k, r_c]
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    val[0] = V.cpu()
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), js_indices, tmat, [C+1,0], *ind_list, list(range(2*C + 2)))
        # collapse every "previous" axis (j, s...) into one trailing axis to maximise over
        V = V.reshape(tuple(kr_shape) + (-1,))
        # V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        # squeeze only the reduced axis so size-1 state/aux axes survive
        ix_tracker[t-1] = max_ix.squeeze(-1)
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze(-1)
        # range(1, T) stops at T-1, so the last step is t == T-1 (t == T never happens)
        if t == T - 1:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()
        
    # NOTE: this inverts state_ix in place (name -> index becomes index -> name),
    # so re-running the cell without re-running the set-up cell inverts it back.
    state_ix = {v:k for k,v in state_ix.items()}
    #Backward pass
    opt_augstateix_list = []
    max_ix = int(torch.argmax(val[T-1]).item())
    unravel_max_ix = np.unravel_index(max_ix, kr_shape)
    opt_augstateix_list =  [np.array(unravel_max_ix).tolist()] + opt_augstateix_list
    
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        # follow the backpointer stored for time T-2-t at the current flat index
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        unravel_max_ix = np.unravel_index(max_ix, kr_shape)
        opt_augstateix_list =  [np.array(unravel_max_ix).tolist()] + opt_augstateix_list

    # first coordinate of each augmented index is the hidden-state index
    opt_state_list = [state_ix[k[0]] for k in opt_augstateix_list]


In [342]:
opt_state_list

[('PRE', None),
 ('PRE', None),
 ('IA', ('S', 'postfix/local')),
 ('WAIT_IA', None),
 ('EX', ('HE', 'img/post')),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('EX', ('S', 'postfix/local')),
 ('EX', ('HE', 'img/post')),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('CA', ('HI', 'usr/query')),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('DI', ('DS', 'syslog/ls')),
 ('WAIT_DI', None),
 ('WAIT_DI', None),
 ('WAIT_DI', None),
 ('CA', ('HI', 'usr/query')),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('EX', ('DS', 'syslog/nano')),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('WAIT_EX', None),
 ('CA', ('HI', 'usr/query')),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('WAIT_CA', None),
 ('WAIT_CA',

In [83]:
t= 1
V = torch.einsum(val[t-1].to(device), js_indices, tmat, [C+1,0], *ind_list, list(range(2*C + 2)))
V = V.reshape(kr_shape + (-1,))
max_ix = torch.argmax(V, axis = -1, keepdims = True)
ix_tracker[t-1] = max_ix.squeeze()
V_reduce = torch.take_along_dim(V, max_ix, axis=-1).squeeze()


In [187]:
t = 1
V2 = torch.einsum('js,jk,krjs -> krjs',val[t-1].to(device),tmat, ind_list[0])

In [190]:
ind_list[1]

[0, 1, 2, 3]

In [195]:
kr_shape

(14, 2)

In [198]:
[k + C + 1 for k in kr_indices]

[2, 3]

In [None]:
def convertTensor_list(hmm, cst_list, sat, dtype = torch.float16, device = 'cpu', return_ix = False):
    '''
    Convert an HMM and a list of constraints into torch tensors laid out for
    the sublist form of torch.einsum.

    Parameters
    ----------
    hmm : object with `states` (iterable of hashable states), `initprob`
        (indexed by state) and `tprob` (indexed by a (from_state, to_state) pair).
    cst_list : list of constraint records. Each must expose `aux_size`,
        `init_fun(k, r)`, `update_fun(k, r, j, s)` and `cst_fun(k, r, sat)`,
        where k, j are states and r, s are tuples of `aux_size` booleans.
    sat : sequence with one entry per constraint. The whole sequence is passed
        to every `cst_fun`.
    dtype, device : dtype and device of every returned tensor.
    return_ix : when True, also return the state -> row-index dictionary.

    Returns
    -------
    hmm_params : [tmat, init_prob] with shapes (K, K) and (K,).
    cst_params : [dims_list, init_ind_list, final_ind_list, ind_list].
        `dims_list[c]` is the auxiliary-space size of constraint c. The three
        `*_ind_list` entries are flat lists [tensor, sublist, tensor, sublist, ...]
        that can be splatted directly into torch.einsum.
    state_ix : dict, only when `return_ix` is True.
    '''
    assert len(cst_list) == len(sat)

    states = list(hmm.states)
    K = len(states)
    C = len(cst_list)
    state_ix = {s: i for i, s in enumerate(states)}

    # Fill on the CPU and move once; element-wise writes to a GPU tensor are slow.
    tmat = torch.zeros((K,K), dtype=dtype)
    init_prob = torch.zeros(K, dtype=dtype)
    for i in states:
        init_prob[state_ix[i]] = hmm.initprob[i]
        for j in states:
            tmat[state_ix[i],state_ix[j]] = hmm.tprob[i,j]

    hmm_params = [tmat.to(device), init_prob.to(device)]

    # einsum axis labels:
    #   0           -> state k
    #   c + 1       -> auxiliary state r_c of constraint c (0 <= c < C)
    #   C + 1       -> state j
    #   c + C + 2   -> auxiliary state s_c of constraint c
    init_ind_list = []
    final_ind_list = []
    ind_list = []
    dims_list = []
    for cst_ix, cst in enumerate(cst_list):
        aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
        aux_ix = {s: i for i, s in enumerate(aux_space)}
        M = len(aux_space)
        ind = torch.zeros((K,M,K,M), dtype=dtype)
        init_ind = torch.zeros((K,M), dtype=dtype)
        final_ind = torch.zeros((K,M), dtype=dtype)

        for r in aux_space:
            for k in states:
                final_ind[state_ix[k], aux_ix[r]] = cst.cst_fun(k,r,sat)
                init_ind[state_ix[k], aux_ix[r]] = cst.init_fun(k,r)
                for s in aux_space:
                    for j in states:
                        ind[state_ix[k],aux_ix[r],state_ix[j],aux_ix[s]] = cst.update_fun(k,r,j,s)

        init_ind_list += [init_ind.to(device), [0, cst_ix + 1]]
        final_ind_list += [final_ind.to(device), [0, cst_ix + 1]]
        ind_list += [ind.to(device), [0, cst_ix + 1, C + 1, cst_ix + C + 2]]
        dims_list.append(M)

    cst_params = [dims_list, init_ind_list, final_ind_list, ind_list]

    if return_ix:
        return hmm_params, cst_params, state_ix
    return hmm_params, cst_params


In [200]:
    # Forward recursion only; expects val[0] and every set-up variable
    # (tmat, ind_list, final_ind_list, emit_weights, kr_indices, dims_list,
    # ix_tracker, T, K, C, device) to exist from earlier cells.
    js_indices = [k + C + 1 for k in kr_indices]  # labels of the "previous" (j, s...) axes
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        # tmat is indexed (j, k) = [C+1, 0], matching the reference einsum above
        V = torch.einsum(val[t-1].to(device), js_indices, tmat, [C+1,0], *ind_list, list(range(2*C + 2)))
        V = V.reshape((K,) + tuple(dims_list) + (-1,))
        V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        # squeeze only the reduced axis so size-1 state/aux axes survive
        ix_tracker[t-1] = max_ix.squeeze(-1)
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze(-1)
        # range(1, T) stops at T-1, so the final step is t == T-1
        if t == T - 1:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()


In [None]:
def Viterbi_torch_list(hmm, cst_list, obs, sat,  time_hom = True, dtype = torch.float16,  device = 'cpu'):
    '''
    Viterbi decoding over the augmented state (k, r_1, ..., r_C), where k is a
    hidden state and r_c is the auxiliary state of constraint c.

    Parameters
    ----------
    hmm : HMM description accepted by compute_emitweights / convertTensor_list.
    cst_list : list of constraint records (see convertTensor_list).
    obs : observation sequence.
    sat : one entry per constraint, forwarded to convertTensor_list.
    time_hom : forwarded to compute_emitweights.
    dtype, device : dtype/device for the model tensors and emission weights.

    Returns
    -------
    opt_state_list : decoded hidden state for every time step.
    opt_augstateix_list : for every time step, the list of integer indices
        [k, r_1, ..., r_C] of the decoded augmented state.

    Notes
    -----
    The final-state indicators are applied at the last time step. Values are
    only normalised at t = 0, so very long sequences may underflow in float16.
    '''
    hmm = copy.deepcopy(hmm) #protect against in-place modification
    #Generate emit_weights; used below as a (T, K) array indexed [t][k]
    emit_weights = compute_emitweights(obs, hmm, time_hom)
    # use the requested dtype so emissions match the model tensors
    emit_weights = torch.from_numpy(emit_weights).type(dtype).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list, state_ix = convertTensor_list(hmm,cst_list, sat, dtype = dtype, \
                                                               device = device, return_ix = True)
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)

    kr_shape = (K,) + tuple(dims_list)
    kr_indices = list(range(C+1))                  # einsum labels of (k, r...)
    js_indices = [k + C + 1 for k in kr_indices]   # einsum labels of (j, s...)

    # Value table and backpointers are kept on the CPU to limit device memory.
    val = torch.empty((T,) + kr_shape, device = 'cpu')
    # backpointers are flattened indices into kr_shape; stored as integers so
    # large augmented state spaces are not rounded by a float representation
    ix_tracker = torch.empty((T,) + kr_shape, dtype = torch.long, device = 'cpu')

    #Forward pass
    # V = torch.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    if T == 1:
        # single observation: the first step is also the last one
        V = torch.einsum(V, kr_indices, *final_ind_list, kr_indices)
    val[0] = V.cpu()
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), js_indices, tmat, [C+1,0], *ind_list, list(range(2*C + 2)))
        # collapse the previous-step axes (j, s...) into one axis to maximise over
        V = V.reshape(kr_shape + (-1,))
        max_ix = torch.argmax(V, dim = -1, keepdim = True)
        # squeeze only the reduced axis so size-1 state/aux axes survive
        ix_tracker[t-1] = max_ix.squeeze(-1).cpu()
        V = torch.take_along_dim(V, max_ix, dim = -1).squeeze(-1)
        # range(1, T) ends at T-1, so the final step is t == T-1
        if t == T - 1:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()

    ix_state = {v:k for k,v in state_ix.items()}

    #Backward pass: start from the best final augmented state, follow backpointers
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing
    max_ix = int(torch.argmax(val[T-1]).item())
    reversed_path = [np.array(np.unravel_index(max_ix, kr_shape)).tolist()]
    for t in range(T-2, -1, -1):
        max_ix = int(ix_tracker[t,max_ix].item())
        reversed_path.append(np.array(np.unravel_index(max_ix, kr_shape)).tolist())
    # append-then-reverse avoids the quadratic cost of prepending to a list
    opt_augstateix_list = reversed_path[::-1]

    # first coordinate of each augmented index is the hidden-state index
    opt_state_list = [ix_state[k[0]] for k in opt_augstateix_list]
    return opt_state_list, opt_augstateix_list


In [None]:
def mv_Viterbi_numpy(hmm_params, emit_weights, cst_params = None):
    '''
    Numpy Viterbi over the augmented state (k, r) for a single constraint.

    Parameters
    ----------
    hmm_params : [tmat, init_prob]; tmat[j, k] weights the move j -> k and
        init_prob has one entry per state.
    emit_weights : array of shape (T, K), emission weight of each state at
        each time step.
    cst_params : None, or [init_ind, final_ind, ind] with shapes (K, M),
        (K, M) and (K, M, K, M). ind[k, r, j, s] weights the move from
        (j, s) to (k, r). When None, the plain Viterbi_numpy is used.

    Returns
    -------
    list of (k, r) integer tuples, one per time step. With cst_params None,
    whatever Viterbi_numpy returns.
    '''
    if cst_params is None:
        return Viterbi_numpy(hmm_params, emit_weights)

    tmat, init_prob = hmm_params
    init_ind,final_ind,ind = cst_params

    T = emit_weights.shape[0]
    K, M = init_ind.shape

    val = np.empty((T,K,M))
    # backpointers are flattened (j, s) indices, so store them as integers
    ix_tracker = np.empty((T,K,M), dtype = np.int64)

    #Forward pass
    val[0] = np.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    if T == 1:
        # single observation: the first step is also the last one
        val[0] = val[0] * final_ind
    for t in range(1,T):
        V = np.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        # collapse (j, s) into one trailing axis to maximise over
        V = V.reshape((K,M,-1))
        max_ix = np.argmax(V, axis = -1, keepdims = True)
        # squeeze only the reduced axis so K == 1 or M == 1 keeps its shape
        ix_tracker[t-1] = max_ix.squeeze(-1)
        V = np.take_along_axis(V, max_ix, axis=-1).squeeze(-1)
        # range(1, T) ends at T-1, so the final step is t == T-1
        if t == T - 1:
            val[t] = np.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
        else:
            val[t] = np.einsum('k,kr -> kr', emit_weights[t],V)

    #Backward pass: start at the best final (k, r), then follow backpointers
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing
    max_ix = int(np.argmax(val[T-1]))
    reversed_path = [divmod(max_ix, M)]
    for t in range(T-2, -1, -1):
        max_ix = int(ix_tracker[t,max_ix])
        reversed_path.append(divmod(max_ix, M))

    # append-then-reverse avoids the quadratic cost of prepending to a list
    return reversed_path[::-1]


In [129]:
def create_noisy_apt(apt_hmm, mix_param, tol = 1e-7):
    '''
    Return a copy of the HMM whose emission probabilities are mixed with a
    uniform distribution over the emission alphabet.

    For every state k and emission e in `apt_hmm.emits`:

        eprob_noisy[k, e] = mix_param * eprob[k, e] + (1 - mix_param) / len(emits)

    so mix_param = 1 reproduces the original emissions and mix_param = 0 gives
    uniform emissions. If each original row sums to one over `emits`, the
    noisy row does too.

    Parameters
    ----------
    apt_hmm : Munch-like HMM with `name`, `states`, `emits`, `tprob`, `eprob`,
        `initprob` and optionally `mu`. It is deep-copied, not modified.
    mix_param : weight given to the original emission probabilities.
    tol : unused; kept so existing callers keep working.

    Returns
    -------
    Munch with the same name, states, emits, tprob and initprob as the copy,
    and the mixed `eprob`. `mu` is carried over when present and truthy.
    '''
    apt = copy.deepcopy(apt_hmm) #work on a copy so the caller's HMM is never touched
    num_emits = len(apt.emits) #size of the emission alphabet

    #Mix every emission row with the uniform distribution
    eprob = defaultdict(int)
    for k in apt.states:
        for e in apt.emits:
            eprob[k,e] = mix_param*apt.eprob[k,e] + (1- mix_param)/num_emits

    new_apt = Munch(name = apt.name, states = apt.states, emits = apt.emits, tprob = apt.tprob, \
                       eprob = eprob, initprob = apt.initprob)

    # `mu` is optional: attribute access on a Munch without it raises
    # AttributeError, so read it defensively.
    mu = getattr(apt, 'mu', None)
    if mu:
        new_apt.mu = mu

    return new_apt


In [95]:
combined_emits[:10]

[None,
 ('V', 'access/bob'),
 None,
 None,
 None,
 ('S', 'postfix/local'),
 ('HE', 'img/post'),
 None,
 ('DS', 'syslog/ls'),
 None]

In [96]:
apt_hmm.initprob['CA']

0

In [98]:
hmm = apt_hmm
obs = combined_emits[:10]

In [None]:
    val = {} #initialize value dictionary

    for k in hmm.states:
            val[0,k] = hmm.initprob[k]*hmm.eprob[k,obs[0]]
            
    ix_tracker = {} #this is used in the backwards step to identify the optimal sequence
    
    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            max_val = -1 # set to dummy variable. will do brute-force search for max
            argmax = None #initialize argmax for ix_tracker
            for j in hmm.states:
                curr_val = val[t-1,j]*hmm.tprob[j,k]
                if curr_val > max_val:
                    max_val = curr_val
                    argmax = j
                    print(j)
            val[t,k] = max_val*hmm.eprob[k,obs[t]]
            ix_tracker[t-1,k] = argmax


In [100]:
val[0,'Pre']

KeyError: (0, 'Pre')

In [None]:
    val = {} #initialize value dictionary

    for k in hmm.states:
            val[0,k] = hmm.initprob[k]*hmm.eprob[k,obs[0]]
            
    ix_tracker = {} #this is used in the backwards step to identify the optimal sequence
    
    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            max_val = -1 # set to dummy variable. will do brute-force search for max
            argmax = None #initialize argmax for ix_tracker
            for j in hmm.states:
                curr_val = val[t-1,j]*hmm.tprob[j,k]
                if curr_val > max_val:
                    max_val = curr_val
                    argmax = j
            val[t,k] = max_val*hmm.eprob[k,obs[t]]
            ix_tracker[t-1,k] = argmax
    
    #Backward: compute the values of the optimal sequence
    max_val = -1
    best_state = None
    for k in hmm.states:
        curr_val = val[len(obs) - 1,k]
        if curr_val > max_val:
            max_val = curr_val
            best_state = k
    opt_state = [best_state]
    
    for t in range(len(obs) -1):
        best_state = ix_tracker[len(obs) - 2 -t,best_state]
        opt_state = [best_state] + opt_state

    return opt_state


In [None]:
apt_hmm_mix.eprob

In [81]:
apt_truth[0]

['PRE',
 'PRE',
 'PRE',
 'PRE',
 'PRE',
 'IA',
 'EX',
 'WAIT_EX',
 'DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'EXF',
 'WAIT_EXF',
 'WAIT_EXF',
 'WA

In [80]:
opt_state

[('PRE', None),
 ('IA', ('S', 'postfix/local')),
 ('EX', ('S', 'postfix/local')),
 ('DI', ('DS', 'syslog/ls')),
 ('COL', ('HE', 'img/post')),
 ('EXF', ('HE', 'img/query')),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', N

In [None]:
len(tier_apt.hmm)

In [122]:
sat = (True,) * 3
opt_aug, opt_state =  mv.mv_Viterbi_v2(obs = combined_emits, hmm = tier_apt, cst= agg_cst,sat = sat)

1
2
3


KeyboardInterrupt: 

In [None]:
mv_Viterbi(obs, hmm, combined_cst, sat = (True,True))