In [87]:
import numpy as np
import json
from munch import Munch
import itertools
from collections import defaultdict
import random
import copy
import pickle


import apt_helper as ahlp
import apt_cst_aggregate as cagg
import mv_Viterbi as mv

In [193]:
import importlib
# Re-import apt_helper so that edits made to the module on disk are picked up
# without restarting the kernel.
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [198]:
# Names handed to the loader; the result is unpacked in the same order
# (APT model first, then the two user models).
names = ['apt','bob','sally']
# One value per entry in `names`, passed to the loader as `delay`
# (presumably a per-model delay parameter -- confirm in apt_helper.process_load).
mu_list = [.8,.9,.9]
apt_hmm, bob_hmm, sally_hmm = ahlp.process_load(names, delay = mu_list)
# Only the user models; the APT model is kept separately in `apt_hmm`.
user_list = [bob_hmm, sally_hmm]

In [235]:
# Check that the user HMMs loaded correctly: the initial distribution should
# sum to 1, and so should every row of the transition and emission matrices.
for usr in user_list:
    # Indices 0, 1, 2 of the result are printed as initprob, tprob and eprob.
    usr_params = ahlp.hmm2numpy(usr)
    print(f'initprob:{usr_params[0].sum()}  tprob: {usr_params[1].sum(axis = 1)}  eprob: {usr_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [236]:
# Same normalisation check as for the user HMMs, applied to the APT model:
# every printed sum should be 1.
apt_params = ahlp.hmm2numpy(apt_hmm)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


# Sanity Check: Base Viterbi Works on Pure Emissions

In [312]:
# Simulate one run from the APT model alone; the result is unpacked as the
# hidden-state path and the emission sequence used by the sanity checks below.
# NOTE(review): exact meaning of ix_list=None / emit_inhom=False lives in apt_helper -- confirm there.
pure_hidden, pure_emission = ahlp.simulation_apt(apt_hmm, ix_list = None, emit_inhom = False)

In [313]:
# Decode the simulated emissions with the plain APT model and report every
# time step at which the decoded state differs from the simulated truth.
opt_state = mv.mv_Viterbi(obs = pure_emission, hmm = apt_hmm)
for t, decoded_state in enumerate(opt_state):
    if decoded_state == pure_hidden[t]:
        continue
    # A mismatch at the very last step is not alarming: nothing forces the
    # decoded path to end on POST.
    print(f'mismatch at time {t} out of {len(opt_state) -1}')

mismatch at time 92 out of 92


# Sanity Check: Noiseless Tiered APT Is Equivalent to the Original

In [314]:
# Build the tiered version of the APT model and repeat the normalisation
# check on it (every printed sum should be 1).
tier_apt = ahlp.create_tiered_apt(apt_hmm)
# Note: this rebinds `apt_params`, which previously held the flat APT parameters.
apt_params = ahlp.hmm2numpy(tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [316]:
# Decode the same noiseless emissions with the tiered model and measure how
# often it agrees with the flat decoding held in `opt_state`.
opt_state_tier = mv.mv_Viterbi(obs = pure_emission, hmm = tier_apt)

num_correct = 0
for t, flat_state in enumerate(opt_state):
    # Each tiered state is indexed with [0] to get the component that is
    # compared against the flat model's state.
    tier_state = opt_state_tier[t][0]
    if flat_state == tier_state:
        num_correct += 1

# Fraction of time steps on which the two decodings agree.
print(num_correct/len(opt_state))

1.0


# Noisy Simulations

In [317]:
# Simulate the APT together with the user models. The result is unpacked as
# the APT ground truth and the combined emission sequence.
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)
# The ground truth is itself a pair: hidden states and the APT's own emissions.
apt_truth_states, apt_truth_emits = apt_truth

In [318]:
# Decode the combined (noisy) emissions with the plain APT model and report
# the fraction of time steps where the decoded state equals the true state.
opt_state = mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)

num_correct = 0
for t, decoded_state in enumerate(opt_state):
    if decoded_state == apt_truth_states[t]:
        num_correct += 1

print(num_correct/len(opt_state))

0.9464285714285714


In [319]:
# Monte Carlo estimate of decoding accuracy: repeat the noisy simulation B
# times and average the per-run fraction of correctly decoded time steps.
# Note that this cell rebinds apt_truth, combined_emits, apt_truth_states,
# apt_truth_emits and opt_state, so later cells see the values of the LAST run.
# NOTE(review): the recorded average is far below the single-run accuracy
# printed by the previous cell -- worth confirming which one is typical.
B = 1000
avg_correct = 0
for b in range(B):
    apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)
    apt_truth_states, apt_truth_emits = apt_truth
    opt_state = mv.mv_Viterbi(obs = combined_emits, hmm = apt_hmm)

    # Count the time steps decoded correctly in this run.
    num_correct = 0
    for t, decoded_state in enumerate(opt_state):
        if decoded_state == apt_truth_states[t]:
            num_correct += 1

    # Progress marker every 50 runs.
    if b % 50 == 0:
        print(b)

    # Each run contributes (its accuracy) / B to the running mean.
    avg_correct += num_correct/(B*len(opt_state))

print(avg_correct)

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
0.4636889008536915


In [239]:
# Re-import apt_cst_aggregate so that edits made to the module on disk are
# picked up without restarting the kernel.
importlib.reload(cagg)

<module 'apt_cst_aggregate' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_cst_aggregate.py'>

In [202]:
# Each constraint lives in its own module, named after the constraint.
cst_names = [
    'know_sally_exists',
    'have_sally_credential',
    'learn_where_data_stored',
    'have_data_on_ds',
    'have_data_on_hi',
    'have_data_on_he',
]

# Import every constraint module and wrap the pieces it exposes in a Munch.
cst_list = []
for name in cst_names:
    module = importlib.import_module(name)
    cst_fields = dict(
        name = module.name,
        aux_size = module.aug_size,  # the module attribute is spelled `aug_size`
        update_fun = module.update_fun,
        init_fun = module.init_fun,
        cst_fun = module.cst_fun,
    )
    # `dependency` is optional: carry it over only when the module defines it.
    if hasattr(module, 'dependency'):
        cst_fields['dependency'] = module.dependency
    curr_cst = Munch(**cst_fields)
    cst_list.append(curr_cst)

# Keep only the first three constraints for the experiments below.
cst_list = cst_list[:3]

In [203]:
# Combine the selected constraints into one aggregate; the result is unpacked
# into three values. NOTE(review): the meaning of zip_list / cst_ix and of
# debug=True is defined in apt_cst_aggregate -- confirm there.
agg_cst, zip_list, cst_ix = cagg.apt_cst_aggregate(cst_list, debug = True)

In [214]:
# Build a mixture of the tiered APT with the user models, sized by the length
# of `combined_emits`. That variable is rebound by more than one earlier cell,
# so the length used here depends on which simulation cell ran last.
tier_apt_mix, ix_list = ahlp.lapt_mixture(tier_apt, user_list, len(combined_emits), mix_weights = None, return_ix = True)

In [215]:
# Noisy variant of the tiered APT, using the helper-module implementation with
# a second argument of 1/3 (presumably the mixing weight on the original
# emissions -- confirm against apt_helper.create_noisy_apt).
noisy_tier_apt = ahlp.create_noisy_apt(tier_apt, 1/3)

In [216]:
# Normalisation check for the noisy tiered model: every printed sum should be 1.
# Note: this rebinds `apt_params` again.
apt_params = ahlp.hmm2numpy(noisy_tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]


In [218]:
# Vanilla Viterbi if no constraint is included: decode the combined emissions
# with the noisy tiered model. Note: this rebinds `opt_state`.
opt_state =  mv.mv_Viterbi(obs = combined_emits, hmm = noisy_tier_apt)

In [129]:
def create_noisy_apt(apt_hmm, mix_param, tol = 1e-7):
    '''
    Return a copy of an HMM whose emissions are blended with uniform noise.

    For every state k and emission e the new emission probability is

        mix_param * eprob[k, e] + (1 - mix_param) / M

    where M = len(apt_hmm.emits). If each row of the original emission table
    sums to one over `emits`, each row of the noisy table does too.
    mix_param = 1 reproduces the original emissions; mix_param = 0 gives a
    uniform emission distribution in every state.

    Parameters
    ----------
    apt_hmm : Munch-like
        HMM with attributes name, states, emits, tprob, eprob, initprob and,
        optionally, mu. It is deep-copied, so the input is never modified.
    mix_param : float
        Weight placed on the original emission probabilities; must lie in [0, 1].
    tol : float
        Currently unused; kept so that existing callers keep working.

    Returns
    -------
    Munch
        New HMM with the noisy `eprob` (a defaultdict, so unseen keys read as
        0) and the remaining fields taken from the deep copy. `mu` is carried
        over whenever the input defines it.

    Raises
    ------
    ValueError
        If mix_param is outside [0, 1], which would yield invalid probabilities.
    '''
    if not 0 <= mix_param <= 1:
        raise ValueError(f'mix_param must be in [0, 1], got {mix_param}')

    apt = copy.deepcopy(apt_hmm) #deepcopy so the returned model shares nothing with the input
    M = len(apt.emits) #number of distinct emissions

    # Mass given to every emission by the uniform component (loop-invariant).
    uniform_mass = (1 - mix_param)/M if M else 0

    # Noisy emission matrix: convex combination of the original and uniform.
    eprob = defaultdict(int)
    for k in apt.states:
        for e in apt.emits:
            eprob[k,e] = mix_param*apt.eprob[k,e] + uniform_mass

    new_apt = Munch(name = apt.name, states = apt.states, emits = apt.emits, tprob = apt.tprob,
                    eprob = eprob, initprob = apt.initprob)

    # `mu` is optional. Plain attribute access on a Munch raises AttributeError
    # when the key is missing, and a truthiness test would drop a legitimate
    # value of 0, so look it up defensively and compare against None.
    mu = getattr(apt, 'mu', None)
    if mu is not None:
        new_apt.mu = mu

    return new_apt


In [95]:
# Peek at the first ten combined observations.
combined_emits[:10]

[None,
 ('V', 'access/bob'),
 None,
 None,
 None,
 ('S', 'postfix/local'),
 ('HE', 'img/post'),
 None,
 ('DS', 'syslog/ls'),
 None]

In [96]:
# Initial probability that the APT model assigns to state 'CA'.
apt_hmm.initprob['CA']

0

In [98]:
# Inputs for stepping through the Viterbi recursion by hand in the scratch
# cells: the plain APT model and the first ten combined observations.
hmm = apt_hmm
obs = combined_emits[:10]

In [None]:
    # Scratch copy of the forward pass of a Viterbi recursion, operating on the
    # notebook-level `hmm` and `obs`. It has no function header and contains a
    # debugging print.
    val = {} #initialize value dictionary

    # Base case: value of starting in state k and emitting the first observation.
    for k in hmm.states:
            val[0,k] = hmm.initprob[k]*hmm.eprob[k,obs[0]]
            
    ix_tracker = {} #this is used in the backwards step to identify the optimal sequence
    
    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            max_val = -1 # set to dummy variable. will do brute-force search for max
            argmax = None #initialize argmax for ix_tracker
            for j in hmm.states:
                curr_val = val[t-1,j]*hmm.tprob[j,k]
                if curr_val > max_val:
                    max_val = curr_val
                    argmax = j
                    print(j)  # debugging output: printed each time the running maximum improves
            val[t,k] = max_val*hmm.eprob[k,obs[t]]
            ix_tracker[t-1,k] = argmax  # best predecessor (at time t-1) of state k at time t


In [100]:
# NOTE(review): the state labels shown elsewhere in this notebook are
# upper-case (e.g. 'PRE'); 'Pre' is not a key of `val`, which is consistent
# with the KeyError recorded for this cell.
val[0,'Pre']

KeyError: (0, 'Pre')

In [None]:
    # Scratch copy of a complete Viterbi decoder body (forward pass followed by
    # backtracking), operating on the notebook-level `hmm` and `obs`.
    # NOTE(review): there is no enclosing `def` in this cell, so the final
    # `return` is only valid once this is placed back inside a function.
    val = {} #initialize value dictionary

    # Base case: value of starting in state k and emitting the first observation.
    for k in hmm.states:
            val[0,k] = hmm.initprob[k]*hmm.eprob[k,obs[0]]
            
    ix_tracker = {} #this is used in the backwards step to identify the optimal sequence
    
    #Forward: compute value function and generate index
    for t in range(1,len(obs)):
        for k in hmm.states:
            max_val = -1 # set to dummy variable. will do brute-force search for max
            argmax = None #initialize argmax for ix_tracker
            for j in hmm.states:
                curr_val = val[t-1,j]*hmm.tprob[j,k]
                if curr_val > max_val:
                    max_val = curr_val
                    argmax = j
            val[t,k] = max_val*hmm.eprob[k,obs[t]]
            ix_tracker[t-1,k] = argmax  # best predecessor (at time t-1) of state k at time t
    
    #Backward: compute the values of the optimal sequence
    # Start from the state with the largest value at the final time step.
    max_val = -1
    best_state = None
    for k in hmm.states:
        curr_val = val[len(obs) - 1,k]
        if curr_val > max_val:
            max_val = curr_val
            best_state = k
    opt_state = [best_state]
    
    # Walk the stored predecessors from time len(obs)-2 down to 0, prepending
    # each state so the list ends up in chronological order.
    for t in range(len(obs) -1):
        best_state = ix_tracker[len(obs) - 2 -t,best_state]
        opt_state = [best_state] + opt_state

    return opt_state


In [None]:
# NOTE(review): `apt_hmm_mix` is not assigned in any visible cell of this
# notebook; this only works if it is still alive in the kernel -- confirm.
apt_hmm_mix.eprob

In [81]:
# True hidden-state path of the most recent combined simulation.
apt_truth[0]

['PRE',
 'PRE',
 'PRE',
 'PRE',
 'PRE',
 'IA',
 'EX',
 'WAIT_EX',
 'DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'DI',
 'WAIT_DI',
 'WAIT_DI',
 'WAIT_DI',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'WAIT_COL',
 'EXF',
 'WAIT_EXF',
 'WAIT_EXF',
 'WA

In [80]:
# Most recently decoded state path (whichever cell last assigned `opt_state`).
opt_state

[('PRE', None),
 ('IA', ('S', 'postfix/local')),
 ('EX', ('S', 'postfix/local')),
 ('DI', ('DS', 'syslog/ls')),
 ('COL', ('HE', 'img/post')),
 ('EXF', ('HE', 'img/query')),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', None),
 ('POST', N

In [None]:
# NOTE(review): elsewhere `tier_apt` is itself passed as the `hmm` argument;
# an `.hmm` attribute on it is not shown anywhere in this notebook -- confirm.
len(tier_apt.hmm)

In [122]:
# Three True flags, matching the three constraints kept in `cst_list`
# (presumably one required-satisfaction flag per constraint -- confirm in mv_Viterbi).
sat = (True,) * 3
# Constrained decoding with the aggregated constraint; the result is unpacked
# into two values. The recorded run of this cell ended in a KeyboardInterrupt.
opt_aug, opt_state =  mv.mv_Viterbi_v2(obs = combined_emits, hmm = tier_apt, cst= agg_cst,sat = sat)

1
2
3


KeyboardInterrupt: 

In [None]:
# NOTE(review): unexecuted scratch call. The module is imported as `mv`, so a
# bare `mv_Viterbi` is not in scope from the imports shown, and `combined_cst`
# is not assigned in any visible cell -- confirm before running.
mv_Viterbi(obs, hmm, combined_cst, sat = (True,True))