In [1]:
import numpy as np
import json
from munch import Munch
import itertools
from collections import defaultdict
import random
import copy
import pickle
import torch
import importlib

import apt_helper as ahlp
import apt_cst_aggregate as cagg
import mv_Viterbi as mv

In [2]:
names = ['apt','bob','sally']
mu_list = [.8,.9,.9]
apt_hmm, bob_hmm, sally_hmm = ahlp.process_load(names, delay = mu_list)
user_list = [bob_hmm, sally_hmm]

In [3]:
#Check if correctly loaded. probabilities should sum to 1.
for usr in user_list:
    usr_params = ahlp.hmm2numpy(usr)
    print(f'initprob:{usr_params[0].sum()}  tprob: {usr_params[1].sum(axis = 1)}  eprob: {usr_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [4]:
apt_params = ahlp.hmm2numpy(apt_hmm)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [5]:
cst_names = ['know_sally_exists','have_sally_credential', 'learn_where_data_stored', 'have_data_on_ds', 'have_data_on_hi', 'have_data_on_he']
cst_names = [names + '_TRUE' for names in cst_names]
cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)

# cst_list = cst_list[:4]
sat = len(cst_list) * (True,)
agg_cst = cagg.apt_cst_aggregate(cst_list)

In [6]:
agg_cst = cagg.apt_cst_aggregate(cst_list)
tier_apt = ahlp.create_tiered_apt(apt_hmm)
apt_params = ahlp.hmm2numpy(tier_apt)
print(f'initprob:{apt_params[0].sum()}  tprob: {apt_params[1].sum(axis = 1)}  eprob: {apt_params[2].sum(axis = 1)}')

initprob:1.0  tprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]  eprob: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]


In [9]:
cst_names = ['dummy_constraint']
cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)
sat = len(cst_list) * (True,)

In [10]:
cst_names = ['know_sally_exists','have_sally_credential', 'learn_where_data_stored', 'have_data_on_ds', 'have_data_on_hi', 'have_data_on_he']
cst_names = [names + '_TRUE' for names in cst_names]
cst_list=  []
for name in cst_names:
    module = importlib.import_module(name)
    
    curr_cst =  Munch(name = module.name, \
                      aux_size = module.aug_size, \
                      update_fun = module.update_fun, \
                      init_fun = module.init_fun, \
                      forbidden_emissions = module.forbidden_emissions, \
                      forbidden_transitions = module.forbidden_transitions, \
                      knowledge_state = module.knowledge_state, \
                      cst_fun = module.cst_fun)
    if hasattr(module, 'dependency'):
        curr_cst.dependency = module.dependency
    cst_list.append(curr_cst)

# cst_list = cst_list[:4]
sat = len(cst_list) * (True,)
agg_cst = cagg.apt_cst_aggregate(cst_list)

In [11]:
apt_params, ix_list = ahlp.hmm2numpy(tier_apt, return_ix = True)
apt_hidden, apt_emits = ahlp.simulation_apt(apt_hmm)
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)

In [12]:
pure_hidden, pure_emission = ahlp.simulation_apt(apt_hmm, ix_list = None, emit_inhom = False)
emit_weights = ahlp.compute_emitweights(pure_emission,tier_apt, time_hom = True)
hmm_params = [apt_params[1], apt_params[0]]
opt_list = mv.Viterbi_numpy(hmm_params, emit_weights)
state_ix, _ = ix_list
state_ix = {v:k for k,v in state_ix.items()}
numpy_list = [state_ix[i] for i in opt_list]

In [13]:
importlib.reload(ahlp)
importlib.reload(cagg)

<module 'apt_cst_aggregate' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_cst_aggregate.py'>

In [14]:
emit_weights = ahlp.compute_emitweights(pure_emission,tier_apt, time_hom = True)
# hmm_params, cst_params = ahlp.arrayConvert(tier_apt, agg_cst, sat)
hmm_params, cst_params = ahlp.arrayConvert(tier_apt, agg_cst, sat)

In [297]:
emit_weights = ahlp.compute_emitweights(combined_emits,tier_apt, time_hom = True)

In [298]:
opt_cst_list = mv_Viterbi_numpy(hmm_params, emit_weights, cst_params)
numpy_cst_list = [state_ix[state[0]] for state in opt_cst_list]

In [170]:
num_correct = 0
for t in range(len(numpy_list)):
    if numpy_list[t] == numpy_cst_list[t]:
        num_correct += 1
print(f'proportion correct: {num_correct/len(numpy_list)}')

proportion correct: 1.0


In [171]:
# opt_list = ahlp.Viterbi_torch_list(tier_apt, cst_list, combined_emits, sat, time_hom = True, device = 'cuda:0')

In [338]:
tier_apt = ahlp.create_tiered_apt(apt_hmm)
print(len(tier_apt.states))

141


In [334]:
obs = combined_emits
device = 'cuda:0'
hmm = tier_apt
time_hom = True
dtype = torch.float32

In [335]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [336]:
    #Generate emit_weights:
    emit_weights = ahlp.compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list, state_ix = ahlp.convertTensor_list(hmm,cst_list, sat, dtype = dtype, \
                                                               device = device, return_ix = True)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    
    #Viterbi
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)
    js_indices = [k + C + 1 for k in kr_indices]


In [327]:
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    val[0] = V.cpu()


In [102]:
t = 1
V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
V = V.reshape((K,) + tuple(dims_list) + (-1,))
V = V/V.max()


In [103]:
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
        V = V.reshape((K,) + tuple(dims_list) + (-1,))
        # V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
        if t == T:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()


In [84]:
max_ix.sum()

tensor(0, device='cuda:0')

In [85]:
    emit_weights = compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list = convertTensor_list(hmm,cst_list, sat, device = device)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    
    #Viterbi
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)
    #Forward pass
    # V = torch.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    val[0] = V.cpu()
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
        V = V.reshape((K,) + tuple(dims_list) + (-1,))
        V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
        if t == T:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()
        

    #Backward pass
    opt_augstateix_list = []
    max_ix = int(torch.argmax(val[T-1]).item())
    unravel_max_ix = np.unravel_index(max_ix, kr_shape)
    opt_augstateix_list.append(np.array(unravel_max_ix).tolist())
    
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        unravel_max_ix = np.unravel_index(max_ix, kr_shape)
        opt_augstateix_list.append(np.array(unravel_max_ix).tolist())


NameError: name 'compute_emitweights' is not defined

In [None]:
def Viterbi_torch_list(hmm, cst_list, obs, sat, time_hom = True, device = 'cpu'):
    '''
    
    '''
    #Generate emit_weights:
    emit_weights = compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list = convertTensor_list(hmm,cst_list, sat, device = device)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    
    #Viterbi
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)
    #Forward pass
    # V = torch.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    val[0] = V.cpu()
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
        V = V.reshape((K,) + tuple(dims_list) + (-1,))
        V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
        if t == T:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()
        

    #Backward pass
    opt_augstateix_list = []
    max_ix = int(torch.argmax(val[T-1]).item())
    unravel_max_ix = np.unravel_index(max_ix, kr_shape)
    opt_augstateix_list.append(np.array(unravel_max_ix).tolist())
    
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        unravel_max_ix = np.unravel_index(max_ix, kr_shape)
        opt_augstateix_list.append(np.array(unravel_max_ix).tolist())

    return opt_augstateix_list


In [33]:
tier_apt_mix, ix_list = ahlp.lapt_mixture(tier_apt, user_list, len(combined_emits), mix_weights = None, return_ix = True)

In [34]:
device = 'cuda:0'

In [35]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [19]:
for cst in cst_list:
    print(cst.forbidden_transitions)

[('EX', 'CA'), ('WAIT_EX', 'CA'), ('DI', 'CA'), ('WAIT_DI', 'CA')]
[]
[('DI', 'COL'), ('WAIT_DI', 'COL')]
[]
[]
[('COL', 'EXF'), ('WAIT_COL', 'EXF')]


In [118]:
x_curr @ emat_curr

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [20]:
#Generate rest
while x_state != 'POST':
    x_curr = random_draw(x_prev @ tmat_curr)
    if emit_inhom:
        y_curr = random_draw(x_curr @ emat_curr[t])
    else:
        y_curr = random_draw(x_curr @ emat_curr)
    x_state = state_ix[np.argmax(x_curr)]
    y_state = emit_ix[np.argmax(y_curr)]
    hid_emit = (x_state,y_state) 
    if hid_emit in notyet_knowledge:
        tmat_mask_dict.pop(hid_emit)
        eprob_mask_dict.pop(hid_emit)
        tmat_curr = tmat * np.prod(list(tmat_mask_dict.values()), axis = 0)
        emat_curr = emat * np.prod(list(eprob_mask_dict.values()), axis = 0)
        notyet_knowledge = list(tmat_mask_dict.keys())
        
    x_list.append(x_state)
    y_list.append(y_state)
    x_prev = x_curr


In [15]:
    def random_draw(p):
        '''
        p is a 1D np array. 
        single random draw from probability vector p and encode as 1-hot.
        '''
        n = len(p)
        if p.sum() <= 0:
            print('Error')
        p = p/p.sum()
        draw = np.random.choice(n,p=p)
        one_hot = np.zeros(n, dtype = int)
        one_hot[draw] = 1
        return one_hot


In [45]:
def simulation_knowledge(hmm, cst_list, ix_list = None, emit_inhom = False):
    '''
    for the apt, generates a run that stops whenever the "POST" state is encountered.
    '''
    #Get numpy version of hmm parameters
    hmm_params, ix_list = hmm2numpy(hmm, ix_list = ix_list, return_ix = True, emit_inhom = emit_inhom) 
    init_prob, tmat, emat = hmm_params
    
    #Create dictionaries for generating mask for transitions/emissions
    state_ix, emit_ix = ix_list
    K, M = len(state_ix), len(emit_ix)
    
    tmat_mask_dict = {}
    eprob_mask_dict = {}
    for cst in cst_list:
        t_mask = np.ones((K,K))
        e_mask = np.ones((K,M))
        for ft in cst.forbidden_transitions:
            t_mask[state_ix[ft[0]],state_ix[ft[1]]] = 0
        for fe in cst.forbidden_emissions:
            e_mask[state_ix[fe[0]],emit_ix[fe[1]]] = 0
        tmat_mask_dict[cst.knowledge_state] = t_mask
        eprob_mask_dict[cst.knowledge_state] = e_mask
    
    state_ix = {v:k for k,v in state_ix.items()}
    emit_ix = {v:k for k,v in emit_ix.items()}
    
    notyet_knowledge = list(tmat_mask_dict.keys())  
    
    tmat_curr = tmat * np.prod(list(tmat_mask_dict.values()), axis = 0)
    emat_curr = emat * np.prod(list(eprob_mask_dict.values()), axis = 0)
    
    x_prev = random_draw(init_prob)
    x_state = state_ix[np.argmax(x_prev)] #convert one-hot back to state
    x_list = [x_state] 
    if emit_inhom:
        y_curr = random_draw(x_prev @ emat_curr[0])
    else:
        y_curr = random_draw(x_prev @ emat_curr)
    y_state = emit_ix[np.argmax(y_curr)]
    y_list = [y_state]

    #Generate rest
    while x_state != 'POST':
        x_curr = random_draw(x_prev @ tmat_curr)
        if emit_inhom:
            y_curr = random_draw(x_curr @ emat_curr[t])
        else:
            y_curr = random_draw(x_curr @ emat_curr)
        x_state = state_ix[np.argmax(x_curr)]
        y_state = emit_ix[np.argmax(y_curr)]
        hid_emit = (x_state,y_state) 
        if hid_emit in notyet_knowledge: #if knowledge state, gets rid of it from the mask
            tmat_mask_dict.pop(hid_emit)
            eprob_mask_dict.pop(hid_emit)
            tmat_curr = tmat * np.prod(list(tmat_mask_dict.values()), axis = 0)
            emat_curr = emat * np.prod(list(eprob_mask_dict.values()), axis = 0)
            notyet_knowledge = list(tmat_mask_dict.keys())
            
        x_list.append(x_state)
        y_list.append(y_state)
        x_prev = x_curr

    return x_list, y_list


In [15]:
opt_augix_list = ahlp.Viterbi_torch_list(tier_apt_mix, cst_list, combined_emits, sat, device = 'cuda:0')

In [20]:
obs = combined_emits
hmm = tier_apt_mix
time_hom = False

emit_weights = ahlp.compute_emitweights(obs, hmm, time_hom)
emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

#Generate hmm,cst params:
hmm_params, cst_params_list = ahlp.convertTensor_list(hmm,cst_list, sat, device = device)   
tmat, init_prob = hmm_params
dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list


#Viterbi
T = emit_weights.shape[0]
K = tmat.shape[0]
C = len(dims_list)

val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu') #will store flattened indices

kr_indices = list(range(C+1))
kr_shape = (K,) + tuple(dims_list)
#Forward pass
# V = torch.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
V = V/V.max() #normalize for numerical stability
val[0] = V.cpu()
for t in range(1,T):
    # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
    V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
    V = V.reshape((K,) + tuple(dims_list) + (-1,))
    V = V/V.max()
    max_ix = torch.argmax(V, axis = -1, keepdims = True)
    ix_tracker[t-1] = max_ix.squeeze()
    V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
    if t == T:
        # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
        val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
    else:
        # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
        val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()


In [24]:
val[92].max()

tensor(0.6074)

In [95]:
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max()
    val[0] = V.cpu()

In [27]:
max_ix.shape

torch.Size([25, 4, 4, 4, 4, 4, 1])

In [None]:
    opt_augstateix_list = []
    max_ix = int(torch.argmax(val[T-1]).item())
    unravel_max_ix = np.unravel_index(max_ix, kr_shape)
    opt_augstateix_list.append(np.array(unravel_max_ix).tolist())
    
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        unravel_max_ix = np.unravel_index(max_ix, kr_shape)
        opt_augstateix_list.append(np.array(unravel_max_ix).tolist())


In [110]:
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        intermediate = torch.einsum(val[t-1].to(device),kr_indices, *ind_list, list(range(2*C + 2)))
        V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
        V = V.reshape((K,) + tuple(dims_list) + (-1,))
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
        if t == T:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()


OutOfMemoryError: CUDA out of memory. Tried to allocate 39.06 GiB. GPU 0 has a total capacity of 39.49 GiB of which 19.47 GiB is free. Including non-PyTorch memory, this process has 20.01 GiB memory in use. Of the allocated memory 237.14 MiB is allocated by PyTorch, and 19.30 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [80]:
def convertTensor_list(hmm, cst_list, sat, device):
    '''
    cst_list is a list of the individual csts.
    '''
    #Initialize and convert all quantities  to np.arrays
    K = len(hmm.states)
    assert len(cst_list) == len(sat)
    
    state_ix = {s: i for i, s in enumerate(hmm.states)}

    #Compute the hmm parameters
    tmat = torch.zeros((K,K), dtype=torch.float16 ).to(device)
    init_prob = torch.zeros(K, dtype=torch.float16 ).to(device)

    for i in hmm.states:
        init_prob[state_ix[i]] = hmm.initprob[i]
        for j in hmm.states:
            tmat[state_ix[i],state_ix[j]] = hmm.tprob[i,j]

    hmm_params = [tmat, init_prob]
    
    #Compute the cst parameters 
    init_ind_list = []
    final_ind_list = []
    ind_list = []
    dims_list = []
    cst_ix = 0
    C = len(cst_list)
    for cst in cst_list:
        aux_space = list(itertools.product([True, False], repeat=cst.aux_size))
        aux_ix = {s: i for i, s in enumerate(aux_space)}
        M = len(aux_space)
        ind = torch.zeros((K,M,K,M),dtype=torch.float16 ).to(device)
        init_ind = torch.zeros((K,M),dtype=torch.float16 ).to(device)
        final_ind = torch.zeros((K,M),dtype=torch.float16 ).to(device)
    
        for r in aux_space:
            for k in hmm.states:
                final_ind[state_ix[k], aux_ix[r]] = cst.cst_fun(k,r,sat)
                init_ind[state_ix[k],aux_ix[r]] = cst.init_fun(k,r)
                for s in aux_space:
                    for j in hmm.states:
                        ind[state_ix[k],aux_ix[r],state_ix[j],aux_ix[s]] = cst.update_fun(k,r,j,s)

        #indices are [0 = k,  (1 dim for each cst r_i = i + 1)  0 <= i <= n - 1 
        # init_ind_list.append((init_ind,[0,cst_ix + 1]))
        # final_ind_list.append((final_ind, [0, cst_ix + 1]))
        # #indices are [0 = k,(1 dim for each cst r_i = i + 1), n + 1 = j, (1 dim for s_i = i+n+2)] 
        # ind_list.append((ind, [0, cst_ix + 1, C + 1, cst_ix + C + 2]))
        # dims_list.append(M)

        init_ind_list += [init_ind,[0,cst_ix + 1]]
        final_ind_list += [final_ind, [0, cst_ix + 1]]
        #indices are kjrs instead of krjs for easier indexing with einsum. 
        ind_list += [ind, [0, 1, 2*cst_ix + 2, 2*cst_ix + 3]]
        dims_list.append(M)
        cst_ix += 1
                
    cst_params = [dims_list, init_ind_list,final_ind_list,ind_list]
    
    return hmm_params, cst_params 


In [17]:
def Viterbi_torch_list(hmm, cst_list, obs, sat, time_hom = True, device = 'cpu'):
    '''
    
    '''
    #Generate emit_weights:
    emit_weights = compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(torch.float16).to(device)

    #Generate hmm,cst params:
    hmm_params, cst_params_list = convertTensor_list(hmm,cst_list, sat, device = device)   
    tmat, init_prob = hmm_params
    dims_list, init_ind_list,final_ind_list,ind_list = cst_params_list

    
    #Viterbi
    T = emit_weights.shape[0]
    K = tmat.shape[0]
    C = len(dims_list)
    
    val = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    ix_tracker = torch.empty((T,K) + tuple(dims_list), device = 'cpu')
    
    kr_indices = list(range(C+1))
    kr_shape = (K,) + tuple(dims_list)
    #Forward pass
    # V = torch.einsum('k,k,kr -> kr', init_prob, emit_weights[0], init_ind)
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V/V.max() #normalize for numerical stability
    val[0] = V.cpu()
    for t in range(1,T):
        # V = torch.einsum('js,jk,krjs -> krjs',val[t-1],tmat,ind)
        V = torch.einsum(val[t-1].to(device), kr_indices, tmat, [0,C+1], *ind_list, list(range(2*C + 2)))
        V = V.reshape(tuple(kr_indices + [-1]))
        V = V/V.max()
        max_ix = torch.argmax(V, axis = -1, keepdims = True)
        ix_tracker[t-1] = max_ix.squeeze()
        V = torch.take_along_dim(V, max_ix, axis=-1).squeeze()
        if t == T:
            # val[t] = torch.einsum('k,kr,kr -> kr',emit_weights[t],final_ind,V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices,*final_ind_list, kr_indices).cpu()
        else:
            # val[t] = torch.einsum('k,kr -> kr', emit_weights[t],V)
            val[t] = torch.einsum(emit_weights[t],[0], V, kr_indices, kr_indices).cpu()
        

    #Backward pass
    opt_augstateix_list = []
    max_ix = int(torch.argmax(val[T-1]).item())
    unravel_max_ix = np.unravel_index(max_ix, kr_shape)
    opt_augstateix_list.append(np.array(unravel_max_ix).tolist())
    
    ix_tracker = ix_tracker.reshape(T,-1) #flatten again for easier indexing    
    
    for t in range(T-1):
        max_ix =  int(ix_tracker[T-2-t,max_ix].item())
        unravel_max_ix = np.unravel_index(max_ix, kr_shape)
        opt_augstateix_list.append(np.array(unravel_max_ix).tolist())

    return opt_augstateix_list


In [23]:
opt_augix_list = ahlp.Viterbi_torch_list(tier_apt_mix, cst_list, combined_emits, sat, device = 'cuda:0')

In [24]:
ix_list

[{('PRE', None): 0,
  ('IA', ('S', 'postfix/local')): 1,
  ('EX', ('V', 'access/bob')): 2,
  ('EX', ('V', 'access/sally')): 3,
  ('EX', ('S', 'postfix/local')): 4,
  ('EX', ('HI', 'img/post')): 5,
  ('EX', ('HE', 'img/post')): 6,
  ('EX', ('DS', 'syslog/nano')): 7,
  ('DI', ('S', 'postfix/local')): 8,
  ('DI', ('HI', 'usr/query')): 9,
  ('DI', ('HI', 'img/query')): 10,
  ('DI', ('HE', 'img/query')): 11,
  ('DI', ('DS', 'syslog/ls')): 12,
  ('CA', ('HI', 'usr/query')): 13,
  ('COL', ('HI', 'img/post')): 14,
  ('COL', ('HE', 'img/post')): 15,
  ('COL', ('DS', 'syslog/nano')): 16,
  ('EXF', ('HE', 'img/query')): 17,
  ('POST', None): 18,
  ('WAIT_DI', None): 19,
  ('WAIT_COL', None): 20,
  ('WAIT_EX', None): 21,
  ('WAIT_CA', None): 22,
  ('WAIT_IA', None): 23,
  ('WAIT_EXF', None): 24},
 {('DS', 'syslog/ls'): 0,
  ('S', 'postfix/local'): 1,
  ('DS', 'syslog/nano'): 2,
  ('HE', 'img/post'): 3,
  ('V', 'access/bob'): 4,
  ('HE', 'img/query'): 5,
  None: 6,
  ('HI', 'img/post'): 7,
  ('HI',

In [58]:
importlib.reload(ahlp)

<module 'apt_helper' from '/home/fyqiu/Projects/conin/conin/mediation_variables/apt_helper.py'>

In [59]:
apt_truth, combined_emits = ahlp.combined_simulation(apt_hmm, user_list)

In [245]:
test_r = (True,) *agg_cst.aux_size
hmm_params, cst_params = ahlp.arrayConvert(tier_apt,agg_cst, sat = test_r)

KeyboardInterrupt: 

In [45]:
emit_weights = compute_emitweights(combined_emits, tier_apt)

In [193]:
K = 25
M = 2**12
T = 20

tmat = np.random.rand(K,K)
tmat = tmat - tmat.min()
tmat = tmat/tmat.sum(axis = -1, keepdims = True)

init_prob = np.random.rand(K)
init_prob = init_prob - init_prob.min()
init_prob = init_prob/init_prob.sum()

emit_weights = np.random.rand(T,K)
emit_weights = emit_weights - emit_weights.min()
emit_weights = emit_weights/emit_weights.max()

hmm_params = [tmat, init_prob]


init_ind = np.random.binomial(1,.01,(K,M))
final_ind = np.random.binomial(1,.01,(K,M))
ind = np.random.binomial(1,.005,(K,M,K,M))

KeyboardInterrupt: 

In [191]:
ind.sum()/ind.flatten().shape

array([0.00499908])

In [176]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_tensor_type(torch.cuda.FloatTensor)

In [177]:
 test_list = mv_Viterbi_numpy(hmm_params, cst_params, emit_weights)

1
2


KeyboardInterrupt: 

In [189]:
def numpy2tensor(hmm_params, cst_params, emit_weights, device):
    '''
    Converts all the numpy arrays to torch tensors
    '''
    hmm_params_torch = [torch.from_numpy(array).to(device) for array in hmm_params]
    cst_params_torch = [torch.from_numpy(array).to(device) for array in cst_params]
    emit_weights_torch = torch.from_numpy(emit_weights).to(device)

    return hmm_params_torch, emit_weights_torch, emit_weights_torch

In [186]:
test_list = mv_Viterbi_torch(hmm_params_torch, cst_params_torch, emit_weights_torch)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
