In [39]:
# import numpy as np
# import torch
# import warnings
# from munch import Munch
# import itertools
# from mv_Viterbi import mv_Viterbi
# from cst_aggregate import cst_aggregate

import numpy as np
import torch
import json
from munch import Munch
import itertools
from collections import defaultdict
import random
import copy
import pickle
import matplotlib.pyplot as plt
import importlib
import time


### Create the HMM

In [12]:
hidden_states = ['pro','ex1','ex2','int','dis','enh']
emit_states = ['A','T', 'C', 'G']
hidden_size, emit_size = len(hidden_states), len(emit_states)

# Row i: distribution of the next hidden state given current state hidden_states[i].
hmm_mat = np.array([
    [.6,.1,.1,.1,.1,0], #promoter
    [0,.4,.2,.2,.1,.1], #exon1
    [.0,.1,.6,.1,.1,.1], #exon2
    [.2,.1,.1,.5,0,.1], #intron
    [0,1/3, 1/3, 0,1/3,0], #disease
    [0,.25,.25,.25,0,.25] #enhancer
])

# Row i: distribution over emit_states (A, T, C, G) for hidden_states[i].
emit_mat = np.array([
    [.1,.1,.4,.4], #CG rich promoter
    [.2,.2,.5,.1], #Exon 1 favors C
    [.5,.1,.2,.2], #Exon 2 favors A
    [.25,.25,.25,.25], #Intron
    [.4,.1,.4,.1], #Disease favors AC
    [.4,.4,.1,.1] #AT rich enhancer
])

# Distribution of the first hidden state.
init_vec = np.array(
    [.2,0,0,.8,0,0]
)

# Every table above must be a probability distribution (rows sum to one).
assert np.allclose(hmm_mat.sum(axis=1), 1), 'transition rows must sum to 1'
assert np.allclose(emit_mat.sum(axis=1), 1), 'emission rows must sum to 1'
assert np.isclose(init_vec.sum(), 1), 'initial distribution must sum to 1'

# Dict views keyed by state/symbol names. .item() keeps every value a plain
# Python float, so all three tables have the same value type.
hmm_transition = {
    (src, dst): hmm_mat[i, j].item()
    for i, src in enumerate(hidden_states)
    for j, dst in enumerate(hidden_states)
}

hmm_emit = {
    (state, symbol): emit_mat[i, j].item()
    for i, state in enumerate(hidden_states)
    for j, symbol in enumerate(emit_states)
}

hmm_startprob = {state: init_vec[i].item() for i, state in enumerate(hidden_states)}

hmm = Munch(states = hidden_states, emits = emit_states, tprob = hmm_transition, eprob = hmm_emit, initprob = hmm_startprob)

### Stay in Each State for at Least 5 Steps ($\geq 5$)

In [23]:
def create_cst_params(cst, hidden_states, dtype = torch.float16, device = 'cpu'):
    '''
    Tabulate a constraint's indicator functions as dense tensors.

    cst must expose m_states (list of mediator states) and the callables
    init_fun(k, r), update_fun(k, s, r) and eval_fun(k, r). Axis order follows
    hidden_states for k and cst.m_states for r / s.

    Returns (init_mat, upd_mat, eval_mat):
        init_mat[k, r]   = init_fun(k, r)        shape (K, M)
        upd_mat[k, s, r] = update_fun(k, s, r)   shape (K, M, M); s = past, r = present mediator
        eval_mat[k, r]   = eval_fun(k, r)        shape (K, M); terminal check
    '''
    mediators = cst.m_states

    def _as_tensor(table):
        # All three tables share the caller-requested dtype/device.
        return torch.tensor(table, dtype = dtype, device = device)

    upd_mat = _as_tensor([
        [[cst.update_fun(k, s, r) for r in mediators] for s in mediators]
        for k in hidden_states
    ])
    init_mat = _as_tensor([[cst.init_fun(k, r) for r in mediators] for k in hidden_states])
    eval_mat = _as_tensor([[cst.eval_fun(k, r) for r in mediators] for k in hidden_states])

    return init_mat, upd_mat, eval_mat


In [24]:
def update_fun(k, r_past, r):
    '''
    Transition indicator for the "stay at least 5 steps in a state" constraint.

    Mediator states are (state, count): the current hidden state and how many
    consecutive steps it has been occupied. The count saturates at 5, so
    count 5 means "5 or more".

    k      : current hidden state
    r_past : previous mediator (prev_state, count)
    r      : candidate current mediator

    Returns True iff r is the successor of r_past under k AND the move is
    allowed: either we stay in prev_state, or we leave it after >= 5 steps.
    '''
    prev, count = r_past
    if k == prev:
        # Saturate: count 6 is not a mediator state. Without the cap, staying
        # longer than 5 steps would have no valid successor.
        new_count = min(count + 1, 5)
    else:
        new_count = 1

    consistency = (count == 5) or (k == prev) #False if we switch state before staying 5 steps

    return (r == (k, new_count)) and consistency

def init_fun(k, r):
    '''
    Initial indicator for the stay constraint: the first mediator must be
    (first hidden state, run length 1).
    '''
    start = (k, 1)
    return r == start

    
def eval_fun(k, r):
    '''
    Terminal indicator for the stay constraint: every final (k, r) is accepted.
    NOTE(review): the final run may therefore be shorter than 5 steps; require
    r[1] == 5 here if that should also be forbidden.
    '''
    accept = True
    return accept

# Mediator states for the stay constraint: (hidden state, run-length counter 1..5).
m_states = [(state, run) for state in hidden_states for run in range(1, 6)]

stay_cst = Munch(dict(update_fun = update_fun, init_fun = init_fun, eval_fun = eval_fun, m_states = m_states))

In [25]:
# Leaving 'pro' after only 4 steps violates the stay constraint.
update_fun(k='enh', r_past=('pro', 4), r=('enh', 1))

False

In [26]:
# a: init (K, M), b: update (K, M, M), c: eval (K, M) tables for the stay constraint.
(a, b, c) = create_cst_params(stay_cst, hidden_states=hidden_states)

### Promoter Must Occur (intended: within the first 30 steps — this window is not yet enforced by the code below)

In [52]:
def update_fun(k, r_past, r):
    '''
    Transition indicator for the promoter-occurrence constraint.

    The mediator is a boolean latch "has 'pro' been visited yet?". It becomes
    True once k == 'pro' and then stays True. Returns True iff r equals the
    updated latch value.
    NOTE(review): no "first 30 steps" window is tracked here; the latch only
    records whether 'pro' occurred at all.
    '''
    latched = (k == 'pro') or r_past
    return r == latched

def init_fun(k, r):
    '''Initial indicator: the latch starts True exactly when the first state is 'pro'.'''
    starts_in_pro = k == 'pro'
    return r == starts_in_pro

def eval_fun(k, r):
    '''Terminal indicator: accept only if the 'pro' latch is set at the final step.'''
    satisfied = (r == True)
    return satisfied

# Boolean mediator: has 'pro' been visited yet?
m_states = [True, False]

promoter_cst = Munch(dict(update_fun = update_fun, init_fun = init_fun, eval_fun = eval_fun, m_states = m_states))

In [53]:
# a: init (K, M), b: update (K, M, M), c: eval (K, M) tables for the promoter constraint.
(a, b, c) = create_cst_params(promoter_cst, hidden_states=hidden_states)

### Visit Disease Exactly Once

In [36]:
def update_fun(k, r_past, r):
    '''
    Transition indicator for the "visit 'dis' exactly once" constraint.

    Mediator r in {0, 1, 2}: number of time steps spent in 'dis' so far,
    saturating at 2 (meaning "two or more"). Consecutive steps in 'dis'
    each count. Returns True iff r is the successor of r_past after
    hidden state k.
    '''
    if k == 'dis':
        # Increment from the PAST count and cap at 2 so the result stays in m_states.
        count = min(r_past + 1, 2)
    else:
        count = r_past

    return r == count

def init_fun(k, r):
    '''Initial indicator: the 'dis' counter starts at 1 if the first state is 'dis', else 0.'''
    start_count = 1 if k == 'dis' else 0
    return r == start_count

def eval_fun(k, r):
    '''
    Terminal indicator: accept only if 'dis' was visited exactly once.

    create_cst_params calls eval_fun(k, r) with the hidden state first, so the
    visit count is the SECOND argument.
    '''
    return r == 1 #must be exactly 1.

# Mediator: number of steps spent in 'dis' so far (2 means "two or more").
m_states = [0, 1, 2]

disvisit_cst = Munch(dict(update_fun = update_fun, init_fun = init_fun, eval_fun = eval_fun, m_states = m_states))

In [54]:
# a: init (K, M), b: update (K, M, M), c: eval (K, M) tables for the 'dis'-once constraint.
(a, b, c) = create_cst_params(disvisit_cst, hidden_states=hidden_states)

### Promoter < Disease < Enhancer

In [7]:
def update_fun(k, r_past, r):
    '''
    Transition indicator for the ordering constraint promoter < disease < enhancer.

    Mediator r = (seen_pro, seen_dis, seen_enh): whether each state has been
    visited so far. Entering 'dis' requires an earlier 'pro', and entering
    'enh' requires an earlier 'dis'. Returns True iff r is the updated flag
    tuple and the move is allowed.
    '''
    occur_pro, occur_dis, occur_enh = r_past

    pro_new = (k == 'pro' or occur_pro)
    dis_new = (k == 'dis' or occur_dis)
    enh_new = (k == 'enh' or occur_enh)

    if k == 'dis':
        consist = occur_pro
    elif k == 'enh':
        consist = occur_dis
    else:
        consist = True

    return (r == (pro_new, dis_new, enh_new)) and consist

def init_fun(k, r):
    '''
    Initial indicator for the ordering constraint.

    r must be the flag tuple for the first state k. The sequence may not start
    in 'dis' or 'enh', because no 'pro' (resp. 'dis') can precede the first step.
    This matches the check update_fun applies at later steps.
    '''
    if k in ('dis', 'enh'):
        return False
    return r == (k == 'pro', k == 'dis', k == 'enh')

def eval_fun(k, r):
    '''Terminal indicator for the ordering constraint: no final condition, everything is accepted.'''
    accept = True
    return accept

# Mediator states: every (seen_pro, seen_dis, seen_enh) flag combination, in itertools.product order.
m_states = [(p, d, e) for p in (True, False) for d in (True, False) for e in (True, False)]

pde_cst = Munch(dict(update_fun = update_fun, init_fun = init_fun, eval_fun = eval_fun, m_states = m_states))



In [8]:
# a: init (K, M), b: update (K, M, M), c: eval (K, M) tables for the ordering constraint.
(a, b, c) = create_cst_params(pde_cst, hidden_states=hidden_states)

NameError: name 'hidden_states' is not defined

In [37]:
# Constraints to enforce jointly: 'dis' exactly once, and pro < dis < enh.
cst_list = list((disvisit_cst, pde_cst))

In [50]:
# device, hmm_params and return_ix are left at their defaults ('cpu', None, False).
test_params, testcst_params = convertTensor_list(hmm, cst_list, dtype=torch.float32)

In [5]:
def create_cst_params(cst, hidden_states, dtype = torch.float16, device = 'cpu'):
    '''
    Tabulate a constraint's indicator functions as dense tensors.

    cst must expose m_states (list of mediator states) and the callables
    init_fun(k, r), update_fun(k, s, r) and eval_fun(k, r).

    Returns (init_mat, upd_mat, eval_mat):
        init_mat[k, r]   = init_fun(k, r)        shape (K, M)
        upd_mat[k, s, r] = update_fun(k, s, r)   shape (K, M, M); s = past, r = present mediator
        eval_mat[k, r]   = eval_fun(k, r)        shape (K, M); terminal check
    '''
    m_states = cst.m_states
    init = cst.init_fun
    upd = cst.update_fun
    evaluate = cst.eval_fun

    upd_mat = torch.tensor([[[upd(k, s, r) for r in m_states] for s in m_states] for k in hidden_states], dtype = dtype, device = device)
    init_mat = torch.tensor([[init(k, r) for r in m_states] for k in hidden_states], dtype = dtype, device = device)
    eval_mat = torch.tensor([[evaluate(k, r) for r in m_states] for k in hidden_states], dtype = dtype, device = device)

    # The body used to stop before building any tensors and implicitly returned
    # None, which breaks `init_mat, upd_mat, eval_mat = create_cst_params(...)`.
    return init_mat, upd_mat, eval_mat


In [43]:
def create_cst_params(cst, hidden_states, dtype = torch.float16, device = 'cpu'):
    '''
    Tabulate a constraint's indicator functions as dense tensors.

    cst must expose m_states (list of mediator states) and the callables
    init_fun(k, r), update_fun(k, s, r) and eval_fun(k, r). Axis order follows
    hidden_states for k and cst.m_states for r / s.

    Returns (init_mat, upd_mat, eval_mat):
        init_mat[k, r]   = init_fun(k, r)        shape (K, M)
        upd_mat[k, s, r] = update_fun(k, s, r)   shape (K, M, M); s = past, r = present mediator
        eval_mat[k, r]   = eval_fun(k, r)        shape (K, M); terminal check
    '''
    mediators = cst.m_states

    def _as_tensor(table):
        # All three tables share the caller-requested dtype/device.
        return torch.tensor(table, dtype = dtype, device = device)

    upd_mat = _as_tensor([
        [[cst.update_fun(k, s, r) for r in mediators] for s in mediators]
        for k in hidden_states
    ])
    init_mat = _as_tensor([[cst.init_fun(k, r) for r in mediators] for k in hidden_states])
    eval_mat = _as_tensor([[cst.eval_fun(k, r) for r in mediators] for k in hidden_states])

    return init_mat, upd_mat, eval_mat

def convertTensor_list(hmm, cst_list, dtype = torch.float16, device = 'cpu', hmm_params = None, return_ix = False):
    '''
    Convert an HMM and a list of constraints into the tensors used by Viterbi_torch_list.

    Parameters
    ----------
    hmm : Munch with states (list), tprob[(i, j)] and initprob[i]
    cst_list : list of constraint Munches (m_states, init_fun, update_fun, eval_fun)
    dtype, device : dtype/device of the produced tensors
    hmm_params : optional precomputed [tmat, init_prob]; built from hmm when None
    return_ix : if True, also return the state -> index mapping

    Returns
    -------
    hmm_params : [tmat (K, K), init_prob (K,)] with tmat[i, j] = tprob[states[i], states[j]]
    cst_params : [dims_list, init_list, eval_list, upd_list]
        dims_list[c] is the number of mediator states of constraint c. Each *_list
        is a flat [tensor, sublist, tensor, sublist, ...] sequence for
        torch.einsum's sublist form. With C constraints the axis labels are:
        0 = current hidden k, c+1 = current mediator of constraint c,
        C+1 = previous hidden j, C+2+c = previous mediator of constraint c.
    state_ix : dict state -> index, only when return_ix is True
    '''
    hmm = copy.deepcopy(hmm) #protect against in-place modification
    K = len(hmm.states)
    state_ix = {s: i for i, s in enumerate(hmm.states)}

    #Compute the hmm parameters if not provided
    if hmm_params is None:
        tmat = torch.zeros((K, K), dtype = dtype, device = device)
        init_prob = torch.zeros(K, dtype = dtype, device = device)
        for i in hmm.states:
            init_prob[state_ix[i]] = float(hmm.initprob[i])
            for j in hmm.states:
                tmat[state_ix[i], state_ix[j]] = float(hmm.tprob[i, j])
        hmm_params = [tmat, init_prob]

    #Compute the cst parameters
    C = len(cst_list)
    dims_list, init_list, eval_list, upd_list = [], [], [], []
    for c, cst in enumerate(cst_list):
        # Tabulate over this HMM's own states (not a global list) so axis 0 matches state_ix.
        init_mat, upd_mat, eval_mat = create_cst_params(cst, hmm.states, dtype = dtype, device = device)
        cur, prev = c + 1, C + 2 + c
        init_list += [init_mat, [0, cur]]
        eval_list += [eval_mat, [0, cur]]
        # upd_mat is 3-D with axes (k, s, r) = (current hidden, past mediator, current mediator).
        upd_list += [upd_mat, [0, prev, cur]]
        dims_list.append(len(cst.m_states))

    cst_params = [dims_list, init_list, eval_list, upd_list]

    if return_ix:
        return hmm_params, cst_params, state_ix
    return hmm_params, cst_params


In [32]:
def compute_emitweights(obs,hmm, time_hom = True):
    '''
    Tabulate emission likelihoods for an observation sequence.

    Returns a (T, K) numpy array whose entry [t, k] is the probability that
    hidden state hmm.states[k] emits obs[t]. With time_hom=True the lookup is
    hmm.eprob[state, symbol]; otherwise it is hmm.eprob[t, state, symbol]
    (time-inhomogeneous emissions).
    '''
    model = copy.deepcopy(hmm)  # guard the caller's hmm against in-place modification
    n_steps, n_states = len(obs), len(model.states)
    weights = np.zeros((n_steps, n_states))

    for t in range(n_steps):
        symbol = obs[t]
        if time_hom:
            row = [model.eprob[state, symbol] for state in model.states]
        else:
            row = [model.eprob[t, state, symbol] for state in model.states]
        weights[t] = np.array(row)

    return weights


In [None]:
def Viterbi_torch_list(hmm, hmm_params, cst_list, obs, dtype = torch.float16, device = 'cpu', debug = False, num_cst = 0, time_hom = True):
    '''
    Constrained Viterbi decoding where the constraints evolve independently (factorially).

    The augmented state at each step is (k, r_1, ..., r_C): hidden state k plus
    one mediator state per constraint. Rather than building a joint update
    tensor U_{krjs}, each constraint's update tensor is contracted along its own
    axes inside a single torch.einsum call. The per-step tensor V_{krjs} over
    (current, previous) augmented states is still materialized.

    To limit numerical underflow, V is divided by (V.max() + num_cst) at each
    step. A positive num_cst also avoids 0/0 when every augmented state has
    zero weight.

    Parameters
    ----------
    hmm : Munch with states, tprob, eprob and initprob
    hmm_params : optional precomputed [tmat, init_prob]; if None they are built from hmm
    cst_list : list of constraint Munches (m_states, init_fun, update_fun, eval_fun)
    obs : observation sequence (must be non-empty)
    dtype, device : dtype/device used for the tensor computations
    debug : if True, also return the value table and the back-pointer table
    num_cst : constant added to the per-step normaliser
    time_hom : passed to compute_emitweights (True = time-homogeneous emissions)

    Returns
    -------
    opt_state_list : decoded hidden-state labels, one per observation
    opt_augstateix_list : per step, [k_ix, r_1_ix, ..., r_C_ix] of the decoded augmented state
    val, ix_tracker : only when debug=True. val[t] holds the scores of every
        augmented state at step t. Row ix_tracker[t] (flattened) holds, for every
        augmented state at t+1, the flattened index of its best predecessor at t.

    Raises
    ------
    ValueError : if obs is empty.
    '''
    hmm = copy.deepcopy(hmm) #protect against in-place modification

    #Generate emit_weights: (T, K)
    emit_weights = compute_emitweights(obs, hmm, time_hom)
    emit_weights = torch.from_numpy(emit_weights).type(dtype).to(device)

    #Generate hmm, cst params (caller-supplied hmm_params are reused if given)
    hmm_params, cst_params_list, state_ix = convertTensor_list(
        hmm, cst_list, dtype = dtype, device = device, hmm_params = hmm_params, return_ix = True)
    tmat, init_prob = hmm_params
    dims_list, init_ind_list, final_ind_list, ind_list = cst_params_list

    T = emit_weights.shape[0]
    if T == 0:
        raise ValueError('obs must contain at least one observation')
    K = tmat.shape[0]
    C = len(dims_list)

    kr_shape = (K,) + tuple(dims_list)
    kr_indices = list(range(C + 1))                 # current (k, r_1, ..., r_C)
    js_indices = [ix + C + 1 for ix in kr_indices]  # previous (j, s_1, ..., s_C)
    krjs_indices = kr_indices + js_indices

    # Stored in the computation dtype so that later einsum operands share one dtype.
    val = torch.zeros((T,) + kr_shape, dtype = dtype, device = 'cpu')
    # Integer back-pointers: float storage loses precision for large flattened indices.
    ix_tracker = torch.zeros((T,) + kr_shape, dtype = torch.long, device = 'cpu')

    #Forward pass
    # V[k, r] = emit[0, k] * init_prob[k] * prod_c init_c[k, r_c]
    V = torch.einsum(emit_weights[0], [0], init_prob, [0], *init_ind_list, kr_indices)
    V = V / (V.max() + num_cst) #normalize for numerical stability
    val[0] = V.cpu()
    for t in range(1, T):
        # V[k, r, j, s] = val[t-1, j, s] * tmat[j, k] * prod_c upd_c[k, s_c, r_c]
        V = torch.einsum(val[t-1].to(device), js_indices, tmat, [C + 1, 0], *ind_list, krjs_indices)
        V = V.reshape(kr_shape + (-1,))  # flatten the predecessor axes (j, s)
        V = V / (V.max() + num_cst)
        best_prev = torch.argmax(V, dim = -1, keepdim = True)
        ix_tracker[t-1] = best_prev.squeeze(-1).cpu()
        # squeeze(-1) only: a bare squeeze() would also drop size-1 state axes.
        V = torch.take_along_dim(V, best_prev, dim = -1).squeeze(-1)
        val[t] = torch.einsum(emit_weights[t], [0], V, kr_indices, kr_indices).cpu()

    # Terminal constraint check eval_c(k, r_c) on the last step. It is applied
    # after the loop so it also runs when T == 1.
    val[T-1] = torch.einsum(val[T-1].to(device), kr_indices, *final_ind_list, kr_indices).cpu()

    #Backward pass
    ix_to_state = {ix: s for s, ix in state_ix.items()}
    flat_tracker = ix_tracker.reshape(T, -1)  # flattened (k, r) index per step
    best = int(torch.argmax(val[T-1]).item())
    path_ix = [best]
    for t in range(T - 1, 0, -1):
        best = int(flat_tracker[t-1, best].item())
        path_ix.append(best)
    path_ix.reverse()

    opt_augstateix_list = [np.array(np.unravel_index(ix, kr_shape)).tolist() for ix in path_ix]
    opt_state_list = [ix_to_state[aug[0]] for aug in opt_augstateix_list]
    if debug:
        return opt_state_list, opt_augstateix_list, val, flat_tracker
    return opt_state_list, opt_augstateix_list


### Inference when the Constraint is Satisfied

Here, we condition on $C=1$: $a$ must happen before $c$. As predicted, when the model encounters an initial run of $C$'s, it chooses $b$, since $c$ is not yet allowed and $b$ has a higher chance of emitting $C$. We see this behavior provided the initial run of $C$'s has length at most 2. If desired, the admissible length of the initial run of $b$'s can be increased by decreasing the emission probabilities $P(A \mid a)$ and $P(C \mid c)$.

In [13]:
obs = list('CACAC')

In [14]:
# NOTE(review): mv_Viterbi and precedence_cst are not defined in this notebook
# (the mv_Viterbi import at the top is commented out); this relies on hidden kernel state.
(opt_aug, opt_state) = mv_Viterbi(obs, hmm, precedence_cst, sat=True)

In [15]:
# Show the decoded hidden-state path returned by mv_Viterbi.
opt_state

['b', 'a', 'c', 'a', 'c']

### Inference when the Constraint is NOT Satisfied

Now we condition on $C = 0$: the constraint is not satisfied. Its logical negation is essentially that $c$ happens before $a$, so the inference problem is symmetric. Accordingly, a short initial run of $A$'s makes the model choose $b$, for the same reasons as above.

In [8]:
obs = list('AACACAC')

In [9]:
# NOTE(review): mv_Viterbi and precedence_cst are not defined in this notebook
# (the mv_Viterbi import at the top is commented out); this relies on hidden kernel state.
(opt_aug, opt_state) = mv_Viterbi(obs, hmm, precedence_cst, sat=False)

In [10]:
# Show the decoded hidden-state path returned by mv_Viterbi.
opt_state

['b', 'b', 'c', 'a', 'c', 'a', 'c']

# Occurrence Constraint

Now we create another constraint class, which enforces that state $b$ must be visited at some point. In this example, this amounts to replacing just one $a$ or $c$ in the unconstrained MAP path with $b$, at whichever time point gives the most probable path.

In [11]:
def update_fun2(r,k , r_past):
    '''
    Transition indicator (as 0/1) for the "b must occur" constraint.

    The mediator is a 1-tuple (seen_b,). seen_b becomes True once k == 'b';
    otherwise it carries over from r_past. Returns 1 iff r == (seen_b,), else 0.
    '''
    seen_b = True if k == 'b' else r_past[0]
    return 1 if r == (seen_b,) else 0

def init_fun2(k, r):
    '''
    Initial indicator (as 0/1) for the "b must occur" constraint: the
    mediator (seen_b,) starts True exactly when the first state is 'b'.
    '''
    starts_with_b = k == 'b'
    return 1 if r == (starts_with_b,) else 0
    
def cst_fun2(r, sat):
    '''
    Constraint "emission" of the final auxiliary state, as 0/1: returns 1 iff
    the seen_b flag r[0] equals the requested truth value sat.
    '''
    seen_b = r[0]
    return 1 if seen_b == sat else 0

In [12]:
# Bundle the "b must occur" constraint in the format consumed by mv_Viterbi.
occurence_cst = Munch(dict(name = 'b must occur', aux_size = 1, update_fun = update_fun2, init_fun = init_fun2, cst_fun = cst_fun2))

In [13]:
obs = list('CCACAC')

In [14]:
# NOTE(review): mv_Viterbi is not defined in this notebook (its import at the top is commented out).
(opt_aug, opt_state) = mv_Viterbi(obs, hmm, occurence_cst, sat=True)

In [15]:
# Show the decoded hidden-state path returned by mv_Viterbi.
opt_state

['b', 'c', 'a', 'c', 'a', 'c']

## Occurrence Constraint is False

If we condition on the constraint being false, this is equivalent to "$b$ is never visited". Since unconstrained inference never returns $b$ here, conditioning on the constraint being false gives the same answer as unconstrained inference.

In [17]:
obs = list('CCACAC')

In [18]:
# NOTE(review): mv_Viterbi is not defined in this notebook (its import at the top is commented out).
(opt_aug, opt_state) = mv_Viterbi(obs, hmm, occurence_cst, sat=False)
opt_state

['c', 'c', 'a', 'c', 'a', 'c']

# Conditioning on Multiple Constraints and Their Values

Now we introduce both the precedence constraint "$a$ happens before $c$" and the occurrence constraint "$b$ must happen at some point" into our model. Again, these are modeled as binary emissions, so we can vary their truth configurations.

In [34]:
# NOTE(review): precedence_cst and cst_aggregate are not defined in this notebook
# (the cst_aggregate import at the top is commented out).
cst_list = list((precedence_cst, occurence_cst))
combined_cst = cst_aggregate(cst_list)
combined_cst.name

['a occurs before c', 'b must occur']

# Both True

First, we assume both constraints are true. The observation sequence below is chosen so that the precedence constraint already makes $b$ appear first, which satisfies the occurrence constraint automatically. Therefore, the answer should be the same as when conditioning on the precedence constraint alone.

In [25]:
obs = list('CCACAC')

In [26]:
# NOTE(review): mv_Viterbi is not defined in this notebook (its import at the top is commented out).
(opt_aug, opt_state) = mv_Viterbi(obs, hmm, combined_cst, sat=(True, True))

In [27]:
# Show the decoded hidden-state path returned by mv_Viterbi.
opt_state

['b', 'b', 'a', 'c', 'a', 'c']

### Precedence True, Occurrence False

Here is an interesting scenario. The occurrence constraint being unsatisfied is equivalent to $b$ never occurring. When the precedence constraint also applies, the path can only use $a$ or $c$, and no $c$ may appear before the first $a$. This means that an initial run of $C$ emissions is forced to begin with $a$, as opposed to $b$ if we were enforcing the precedence constraint by itself.

In [32]:
obs = list('CCACAC')

In [33]:
# NOTE(review): mv_Viterbi is not defined in this notebook (its import at the top is commented out).
(opt_aug, opt_state) = mv_Viterbi(obs, hmm, combined_cst, sat=(True, False))
opt_state

['a', 'c', 'a', 'c', 'a', 'c']