In [1]:
from IPython.display import SVG
import numpy as np
import pandas as pd
from scipy.linalg import expm
import itertools
from scipy.stats import truncexpon
from scipy.stats import expon
import scipy.special
import ast
%load_ext rpy2.ipython

In [2]:
def iter_lst_to_ggplot(iter_lst_2):
    """
    Convert a list of states into a long data frame for the R plotting cell.

    Every element of every state contributes one row. An element is read as
    a pair ``(left, right)``; a value of 0 on one side means that side is
    absent, so the segment is drawn only on the other half of [-1, 1] and
    the missing end gets the placeholder label '-'.

    Parameters
    ----------
    iter_lst_2 : list of lists of pairs
        States to plot; ``iter_lst_2[i][j]`` is indexed with ``[0]`` (left)
        and ``[1]`` (right).

    Returns
    -------
    pandas.DataFrame
        Columns 'id' (state index), 'seg_y' (element index within the state),
        'seg_xmin', 'seg_xmax', 'dot_color_left', 'dot_color_right'.
    """
    columns = ['id', 'seg_y', 'seg_xmin', 'seg_xmax', 'dot_color_left', 'dot_color_right']
    # Collect plain records and build the frame once; appending rows with
    # .loc in a loop is quadratic and leaves numeric columns as object dtype.
    records = []
    for state_id, state in enumerate(iter_lst_2):
        for seg_y, pair in enumerate(state):
            left, right = pair[0], pair[1]
            if left == 0:
                # Only the right side is present.
                records.append([state_id, seg_y, 0, 1, '-', str(right)])
            elif right == 0:
                # Only the left side is present.
                records.append([state_id, seg_y, -1, 0, str(left), '-'])
            else:
                records.append([state_id, seg_y, -1, 1, str(left), str(right)])
    return pd.DataFrame(records, columns=columns)

In [3]:
%%R

library(tidyverse)

# Plot each state as its own facet: one horizontal segment per element of
# the state, with a point at each end labelled by the side's value.
# `dat` is presumably the data frame built by iter_lst_to_ggplot (same
# column names: id, seg_y, seg_xmin, seg_xmax, dot_color_left,
# dot_color_right) — confirm against the caller.
plot_states <- function(dat) {
    plt <- as_tibble(dat) %>%
        # '-' marks a missing end; turn it into NA so the end can be hidden.
        mutate(
            dot_color_left = ifelse(dot_color_left == '-', NA, dot_color_left),
            dot_color_right = ifelse(dot_color_right == '-', NA, dot_color_right)
        ) %>%
        ggplot() +
        # Segment from seg_xmin to seg_xmax at height seg_y.
        geom_segment(aes(x = seg_xmin, xend = seg_xmax,
                         y = seg_y, yend = seg_y)) +
        # Left end point. Outline colour is keyed on whether the end is
        # missing, fill on the label, and shape on whether the label is 1.
        # NOTE(review): for missing ends `dot_color_left == 1` is NA, so the
        # shape aesthetic is NA there — ggplot may drop/warn about these rows.
        geom_point(aes(x = seg_xmin, y = seg_y, 
                       color = is.na(dot_color_left), 
                       fill = dot_color_left,
                  shape = dot_color_left == 1),
                  size = 4) +
        # Right end point, same encoding as the left one.
        geom_point(aes(x = seg_xmax, y = seg_y, 
                       color = is.na(dot_color_right), 
                       fill = dot_color_right,
                   shape = dot_color_right == 1),
                   size = 4) +
        theme_void() +
        # One panel per state id.
        facet_wrap(~id, scales = 'free') +
        # Fill colour per label '1'..'7' (presumably matched to `breaks` in
        # order, so '1' and '7' are both black); missing ends are transparent.
        scale_fill_manual(
            na.value = 'transparent',
            values = c('black', 'white', 'purple', 'yellow', 'orange', 'green', 'black'),
            breaks = as.character(c(1, 2, 3, 4, 5, 6, 7))
        ) +
        # Outline: FALSE (end present) -> black, TRUE (end missing) -> transparent.
        scale_color_manual(
            values = c('black', 'transparent')
        ) +
        # Shape: FALSE (label != 1) -> 24 (triangle), TRUE (label == 1) -> 21 (circle).
        scale_shape_manual(values=c(24, 21)) +
        scale_x_continuous(expand = c(0.2, 0.2)) +
        scale_y_continuous(expand = c(0.2, 0.2)) +
        # Hide legend and grid; keep a black border around each panel.
        theme(legend.position = 'none',
              panel.border = element_rect(colour = "black", fill = NA, size = 1),
              axis.line=element_blank(),
              panel.background=element_blank(),panel.grid.major=element_blank(),
              panel.grid.minor=element_blank(),plot.background=element_blank()) 
    plt
}

R[write to console]: ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──

R[write to console]: ✔ ggplot2 3.3.5     ✔ purrr   0.3.4
✔ tibble  3.1.5     ✔ dplyr   1.0.7
✔ tidyr   1.1.4     ✔ stringr 1.4.0
✔ readr   2.0.2     ✔ forcats 0.5.1

R[write to console]: ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()



In [4]:
def get_trans_mat(iter_lst, coal, rho):
    """
    Build a CTMC rate matrix from a list of states.

    For each state, the states returned by ``recombination(state)`` get rate
    ``rho`` and those returned by ``coalescent(state)`` get rate ``coal``.
    The diagonal is then set so that every row sums to zero.

    Parameters
    ----------
    iter_lst : list
        States; targets returned by ``recombination``/``coalescent`` must be
        members of this list (looked up with ``list.index``).
    coal : float
        Coalescent rate.
    rho : float
        Recombination rate.

    Returns
    -------
    numpy.ndarray
        Square ``(len(iter_lst), len(iter_lst))`` rate matrix.
    """
    n_states = len(iter_lst)
    trans_mat = np.zeros((n_states, n_states))
    for i, state in enumerate(iter_lst):
        # NOTE(review): repeated targets overwrite rather than accumulate,
        # and a coalescent target overwrites a recombination target that is
        # the same state — confirm this is the intended rate model.
        for target in recombination(state):
            trans_mat[i, iter_lst.index(target)] = rho
        for target in coalescent(state):
            trans_mat[i, iter_lst.index(target)] = coal
    for i in range(n_states):
        # A self-transition is meaningless in a CTMC and must not leak into
        # the row sum; otherwise the row would not sum to zero.
        trans_mat[i, i] = 0.0
        trans_mat[i, i] = -sum(trans_mat[i])
    return trans_mat

In [5]:
def combine_states(iter_lst_a, iter_lst_b, probs_a, probs_b):
    """
    Merge two independent state distributions into one.

    Every pair (state from ``iter_lst_a``, state from ``iter_lst_b``) is
    concatenated and sorted into a combined state, weighted by the product
    of the two probabilities. Combined states with the same string form are
    summed together.

    Returns
    -------
    (names, values)
        ``names`` are the string forms of the combined states in sorted
        order; ``values`` are their summed probabilities.
    """
    index_pairs = list(itertools.product(range(len(iter_lst_a)), range(len(iter_lst_b))))
    merged = [sorted(iter_lst_a[i] + iter_lst_b[j]) for i, j in index_pairs]
    weights = [probs_a[i] * probs_b[j] for i, j in index_pairs]
    # Group identical combined states (by their string form) and add up probabilities.
    grouped = (
        pd.DataFrame({'name': [str(state) for state in merged], 'value': weights})
        .groupby('name', as_index=False)
        .sum()
    )
    return list(grouped['name']), list(grouped['value'])

In [6]:
def trans_mat_num(trans_mat, coal, rho):
    """
    Turn a symbolic rate matrix into a numeric one.

    Each entry of ``trans_mat`` is either '0' (no transition) or an integer
    multiplier followed by a rate letter: 'C' for coalescence (multiplied by
    ``coal``) or 'R' for recombination (multiplied by ``rho``), e.g. '2C' or
    '10R'. The diagonal is set to minus the sum of the off-diagonal entries
    of each row.

    Parameters
    ----------
    trans_mat : numpy.ndarray
        2-D array of string labels.
    coal : float
        Coalescent rate.
    rho : float
        Recombination rate.

    Returns
    -------
    numpy.ndarray
        Float matrix of the same shape.

    Raises
    ------
    ValueError
        If an entry ends in a letter other than 'C' or 'R'.
    """
    num_rows, num_cols = trans_mat.shape
    rates = np.zeros((num_rows, num_cols))
    for i in range(num_rows):
        for j in range(num_cols):
            entry = trans_mat[i, j]
            if entry == '0':
                continue
            # Multiplier may have several digits; the rate letter is last.
            factor, kind = entry[:-1], entry[-1]
            if kind == 'C':
                base = coal
            elif kind == 'R':
                base = rho
            else:
                raise ValueError(f"unrecognised rate label {entry!r} at ({i}, {j})")
            rates[i, j] = int(factor) * base
    for i in range(min(num_rows, num_cols)):
        # Ignore any self-transition so each row sums to zero.
        rates[i, i] = 0.0
        rates[i, i] = -sum(rates[i])
    return rates

In [7]:
def cutpoints_AB(n_int_AB, t_AB, coal_AB):
    """
    Interval boundaries for the AB epoch.

    Returns the ``n_int_AB + 1`` quantiles (at probabilities 0, 1/n, ..., 1)
    of an exponential distribution with rate ``coal_AB`` truncated to
    ``[0, t_AB]``, so the first value is 0 and the last is ``t_AB``.
    """
    probs = np.arange(n_int_AB + 1) / n_int_AB
    scale = 1 / coal_AB
    # truncexpon's shape b is the truncation point in units of the scale.
    return truncexpon.ppf(probs, b=t_AB / scale, loc=0, scale=scale)

In [8]:
def cutpoints_ABC(n_int_ABC, coal_ABC):
    """
    Interval boundaries for the ABC epoch.

    Returns the ``n_int_ABC + 1`` quantiles (at probabilities 0, 1/n, ..., 1)
    of an exponential distribution with rate ``coal_ABC``; the last value
    is infinite.
    """
    probs_ABC = np.arange(n_int_ABC + 1) / n_int_ABC
    return expon.ppf(probs_ABC, scale=1 / coal_ABC)

In [9]:
def get_ABC(trans_mat, times, omegas):
    """
    Multiply restricted transition-probability matrices along a path of
    state subsets.

    For step k the chain runs for ``times[k]`` (``expm(trans_mat * times[k])``)
    and only rows ``omegas[k]`` and columns ``omegas[k + 1]`` are kept; the
    resulting blocks are multiplied in order. In this notebook it is used
    with both the two-sequence and the three-sequence rate matrices.

    Parameters
    ----------
    trans_mat : numpy array
        Square CTMC rate matrix.
    times : sequence of numbers
        Duration of each step; must be non-empty.
    omegas : sequence of index lists
        State subsets; at least ``len(times) + 1`` entries are read.

    Returns
    -------
    numpy array
        Shape ``(len(omegas[0]), len(omegas[len(times)]))``.
    """
    def _block(step):
        # Transition probabilities over times[step], restricted to
        # omegas[step] -> omegas[step + 1].
        return expm(trans_mat * times[step])[np.ix_(omegas[step], omegas[step + 1])]

    product = _block(0)
    for step in range(1, len(times)):
        product = product @ _block(step)
    return product

In [1379]:
def get_ABC_inf(trans_mat, times, omegas, coal):
    """
    Variant of get_ABC whose last step is a single coalescence event.

    All but the final step are computed with
    ``get_ABC(trans_mat, times[:-1], omegas[:-1])``. The result is then
    scaled by ``1 - exp(-times[-1] * coal)``, the probability that an event
    with rate ``coal`` happens within the last interval (equal to 1 when
    that interval is infinite). ``omegas[-1]`` is not used.

    Parameters
    ----------
    trans_mat : numpy array
        CTMC rate matrix.
    times : list of numbers
        Step durations; the last one is the final interval.
    omegas : list of lists
        State subsets, one more than ``times``.
    coal : float
        Rate of the final coalescence event.
    """
    head = get_ABC(trans_mat, times[:-1], omegas[:-1])
    last_interval = times[-1]
    return head * (1 - np.exp(-last_interval * coal))

In [1368]:
def get_ordered(p_ABC, omega_end, omega_tot):
    """
    Spread a probability vector over a subset back onto the full state list.

    ``p_ABC[k]`` belongs to state ``omega_end[k]``. The result has one entry
    per state in ``omega_tot``: the matching probability, or 0 for states
    not in ``omega_end``.
    """
    position = {}
    for k, state in enumerate(omega_end):
        # Keep the first occurrence, mirroring list.index semantics.
        position.setdefault(state, k)
    return np.array([p_ABC[position[state]] if state in position else 0
                     for state in omega_tot])

In [1369]:
def get_times(cut, intervals):
    """
    Durations between consecutive selected cut points.

    ``intervals`` holds indices into ``cut`` (negative indices allowed); the
    result is ``cut[intervals[k + 1]] - cut[intervals[k]]`` for each k.
    """
    picked = [cut[k] for k in intervals]
    return [hi - lo for lo, hi in zip(picked, picked[1:])]

In [1370]:
def get_tab_AB(state_space_AB, trans_mat_AB, cut_AB, pi_AB):
    """
    Tabulate when the left and right sides coalesce in the two-sequence CTMC.

    Parameters
    ----------
    state_space_AB : list
        Two-sequence states. Each state is flattened with ``sum(state, ())``;
        even positions of the flattened state refer to the left side and odd
        positions to the right side, and a value of 2 marks coalescence.
    trans_mat_AB : numpy array
        Rate matrix ordered like ``state_space_AB``.
    cut_AB : array-like
        Interval boundaries; ``cut_AB[0]`` and ``cut_AB[-1]`` delimit the epoch.
    pi_AB : array-like
        Starting distribution over ``state_space_AB``.

    Returns
    -------
    tab_names : list
        One label per row, in this order:
        ``('D', 'D')`` (neither side coalesced by ``cut_AB[-1]``),
        ``((0, L), 'D')`` for L in range(n) (left only, in interval L),
        ``('D', (0, R))`` for R in range(n) (right only, in interval R),
        ``((0, L), (0, R))`` for R, then L in range(n) (both sides).
    tab : numpy array
        Shape ``((n + 1) ** 2, len(state_space_AB))``; row k is the
        end-of-epoch state distribution for the event ``tab_names[k]``.
    """
    n_int_AB = len(cut_AB) - 1
    # Derive the state count instead of hard-coding it (previously 9).
    n_states = len(state_space_AB)
    omega_tot_AB = list(range(n_states))

    flatten = [list(sum(i, ())) for i in state_space_AB]
    # Neither side coalesced.
    omega_B = [i for i in omega_tot_AB if 2 not in flatten[i]]
    # Left coalesced, right not.
    omega_L = [i for i in omega_tot_AB if (2 in flatten[i][::2]) and (2 not in flatten[i][1::2])]
    # Right coalesced, left not.
    omega_R = [i for i in omega_tot_AB if (2 not in flatten[i][::2]) and (2 in flatten[i][1::2])]
    # Both sides coalesced.
    omega_E = [i for i in omega_tot_AB if (2 in flatten[i][::2]) and (2 in flatten[i][1::2])]

    # Each spec is (row label, indices into cut_AB, chain of state subsets).
    specs = [(('D', 'D'), [0, -1], [omega_tot_AB, omega_B])]
    for L in range(n_int_AB):
        specs.append((((0, L), 'D'), [0, L, L + 1, -1],
                      [omega_tot_AB, omega_B, omega_L, omega_L]))
    for R in range(n_int_AB):
        specs.append((('D', (0, R)), [0, R, R + 1, -1],
                      [omega_tot_AB, omega_B, omega_R, omega_R]))
    for R in range(n_int_AB):
        for L in range(n_int_AB):
            if L == R:
                bounds = [0, L, L + 1, -1]
                omegas = [omega_tot_AB, omega_B, omega_E, omega_E]
            elif L < R:
                # Left coalesces first (interval L), then right (interval R).
                bounds = [0, L, L + 1, R, R + 1, -1]
                omegas = [omega_tot_AB, omega_B, omega_L, omega_L, omega_E, omega_E]
            else:
                # Right coalesces first (interval R), then left (interval L).
                bounds = [0, R, R + 1, L, L + 1, -1]
                omegas = [omega_tot_AB, omega_B, omega_R, omega_R, omega_E, omega_E]
            specs.append((((0, L), (0, R)), bounds, omegas))

    tab = np.zeros((len(specs), n_states))
    for k, (_, bounds, omegas) in enumerate(specs):
        times = get_times(cut_AB, bounds)
        p_ABC = pi_AB @ get_ABC(trans_mat_AB, times, omegas)
        # Spread the end-subset probabilities back over all states.
        tab[k] = get_ordered(p_ABC, omegas[-1], omega_tot_AB)
    tab_names = [name for name, _, _ in specs]
    return tab_names, tab

In [1371]:
def get_tab_ABC(state_space_ABC, trans_mat_ABC, coal_ABC, cut_ABC, pi_ABC, names_tab_AB, n_int_AB):
    """
    Tabulate joint left/right coalescence intervals in the three-sequence
    (ABC) epoch, starting from the AB-epoch outcomes.

    NOTE(review): this function is mid-debugging. It prints diagnostics and
    hits a bare ``return`` right after the "V0 -> V0" section, so it
    currently returns None. Everything after that ``return`` (the
    "V0 -> deep" and "deep -> deep" sections and the final DataFrame) is
    unreachable. get_HMM_trans_mat returns this value directly.

    Parameters
    ----------
    state_space_ABC : list
        Three-sequence states. Each is flattened with ``sum(state, ())``;
        even positions refer to the left side, odd positions to the right.
    trans_mat_ABC : numpy array
        Rate matrix of the three-sequence CTMC. The code assumes exactly 31
        states (hard-coded ``range(31)``).
    coal_ABC : float
        Rate passed to get_ABC_inf for the final-interval event.
    cut_ABC : array-like
        Interval boundaries of the ABC epoch.
    pi_ABC : numpy array
        Starting distributions, one row per label in ``names_tab_AB``.
    names_tab_AB : list
        Row labels from get_tab_AB; used to select rows of ``pi_ABC``.
    n_int_AB : int
        Number of AB intervals.
    """
    
    n_int_ABC = len(cut_ABC)-1
    
    # Get flatten list of states, where even-indexed numbers (0, 2, ...)
    # represent the left-side coalescence states and odd-indexed numbers
    # (1, 3, ...) represent right-side coalescence.
    flatten = [list(sum(i, ())) for i in state_space_ABC]
    # State subsets omega_XY: X describes the left side, Y the right side;
    # 0 = neither 2 nor 3 present, 1 = a 2 present, 2 = a 3 present.
    # Presumably 2/3 mark lineages carrying two/all three sequences, i.e.
    # one/two coalescences — TODO confirm against the state-space generator.
    omega_tot_ABC = [i for i in range(31)]
    omega_00 = [i for i in range(31) if all(x not in [2, 3] for x in flatten[i])]
    omega_10 = [i for i in range(31) if (2 in flatten[i][::2]) and (all(x not in [2, 3] for x in flatten[i][1::2]))]
    omega_01 = [i for i in range(31) if (all(x not in [2, 3] for x in flatten[i][::2])) and (2 in flatten[i][1::2])]
    omega_11 = [i for i in range(31) if (2 in flatten[i][::2]) and (2 in flatten[i][1::2])]
    omega_12 = [i for i in range(31) if (2 in flatten[i][::2]) and (3 in flatten[i][1::2])]
    omega_21 = [i for i in range(31) if (3 in flatten[i][::2]) and (2 in flatten[i][1::2])]
    omega_22 = [i for i in range(31) if (3 in flatten[i][::2]) and (3 in flatten[i][1::2])]
    omega_20 = [i for i in range(31) if (3 in flatten[i][::2]) and (all(x not in [2, 3] for x in flatten[i][1::2]))]
    # NOTE(review): omega_02 is not used anywhere below.
    omega_02 = [i for i in range(31) if (all(x not in [2, 3] for x in flatten[i][::2])) and (3 in flatten[i][1::2])]
    
        
    # Only referenced by a commented-out print at the end of the function.
    n_markov_states = n_int_AB**2+n_int_ABC+3*scipy.special.comb(n_int_ABC, 2, exact = True)
    
    tab = []
    acc = 0
    # V0 -> V0
    # AB rows labelled ((0, l), (0, r)): both sides coalesced once during the
    # AB epoch, in intervals l and r. Their second coalescences happen in
    # ABC intervals L (left) and R (right).
    for l in range(n_int_AB):
        for r in range(n_int_AB):
            
            cond = [i == ((0, l),(0, r)) for i in names_tab_AB]
            pi = pi_ABC[cond]
            # Total starting mass; printed below as a sanity check against
            # the summed table entries.
            acc += pi.sum()
            
            #print('-----')
            #print((pi).sum())
            #print()
            for L in range(n_int_ABC):
                for R in range(n_int_ABC):
                    # print((0, l, L), (0, r, R))
                    if L < R:                        
                        # Left finishes in L (omega_11 -> omega_21), right
                        # finishes in R (omega_21 -> omega_22).
                        times_ABC = get_times(cut_ABC, [0, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_11, omega_21, omega_21, omega_22]
                        p_ABC = get_ABC_inf(trans_mat_ABC, times_ABC, omegas_ABC, coal_ABC)
                        
                        #print((0, l, L), (0, r, R))
                        #print((p_ABC[[i for i in range(len(pi[0])) if pi[0][i] != 0]]).sum())
                        #print((np.array([[0 if i == 0 else 1 for i in pi[0]]])@p_ABC).sum())
                        #print((pi).sum())
                        #print((pi@p_ABC).sum()/2)
                        
                        #print((0, r, R), (0, l, L))
                        #print((p_ABC[[i for i in range(len(pi[0])) if pi[0][i] != 0]]).sum())
                        #print((np.array([[0 if i == 0 else 1 for i in pi[0]]])@p_ABC).sum())
                        #print((pi).sum())
                        #print((pi@p_ABC).sum()/2)
                        
                        # NOTE(review): the mirrored entry reuses the value
                        # computed from the ((0, l), (0, r)) starting mass along
                        # the left-first path, and L > R is skipped below. A
                        # right-first path would go through omega_12 instead.
                        # This assumes left/right symmetry — confirm.
                        tab.append([(0, l, L), (0, r, R), (pi@p_ABC).sum()])
                        tab.append([(0, r, R), (0, l, L), (pi@p_ABC).sum()])
                    elif L == R:
                        # Both sides finish in the same interval L.
                        times_ABC = get_times(cut_ABC, [0, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_11, omega_22]
                        p_ABC = get_ABC_inf(trans_mat_ABC, times_ABC, omegas_ABC, coal_ABC)
                        
                        #print((0, l, L), (0, r, R))
                        #print((p_ABC[[i for i in range(len(pi[0])) if pi[0][i] != 0]]).sum())
                        #print((np.array([[0 if i == 0 else 1 for i in pi[0]]])@p_ABC).sum())
                        #print((pi).sum())
                        #print((pi@p_ABC).sum())
                        
                        tab.append([(0, l, L), (0, r, R), (pi@p_ABC).sum()])
                    else:
                        continue
    
        
    # Diagnostics: starting mass vs. total mass recorded in `tab`.
    print('----')
    print(acc) 
    print(sum([i[2] for i in tab])) 
    print('----')
    
    # print((n_int_AB*n_int_ABC)**2)
    # print(len(tab)) 
    
    # NOTE(review): debugging early exit — returns None, and all code below
    # this line is unreachable.
    return
    
        
    # NOTE(review): acc2 is never updated; it only appears in a commented-out print.
    acc2 = 0
    acc3 = 0
    # V0 -> deep
    # AB rows labelled ((0, l), 'D'): left coalesced once in AB interval l,
    # right not yet. L = left's second coalescence; r / R = right's first /
    # second coalescence (all in ABC intervals).
    for l in range(n_int_AB):
        cond = [i == ((0, l),'D') for i in names_tab_AB]
        pi = pi_ABC[cond]
        # NOTE(review): pi is normalised here but not in the V0 -> V0 section
        # above — confirm which is intended.
        pi = pi/pi.sum()
        for L in range(n_int_ABC):
            for r in range(n_int_ABC):
                for R in range(r, n_int_ABC):
                    if L < r < R:
                        times_ABC = get_times(cut_ABC, [0, L, L+1, r, r+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_20, omega_20, omega_21, omega_21, omega_22]
                    elif L == r < R:
                        times_ABC = get_times(cut_ABC, [0, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_21, omega_21, omega_22]
                    elif r < L < R:
                        times_ABC = get_times(cut_ABC, [0, r, r+1, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_11, omega_11, omega_21, omega_21, omega_22]
                    elif r < L == R:
                        times_ABC = get_times(cut_ABC, [0, r, r+1, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_11, omega_11, omega_22]
                    elif r < R < L:
                        times_ABC = get_times(cut_ABC, [0, r, r+1, R, R+1, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_11, omega_11, omega_12, omega_12, omega_22]
                    elif L < r == R:
                        times_ABC = get_times(cut_ABC, [0, L, L+1, r, r+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_20, omega_20, omega_22]
                    elif L == r == R:
                        times_ABC = get_times(cut_ABC, [0, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_22]
                    elif r == R < L:
                        times_ABC = get_times(cut_ABC, [0, r, r+1, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_10, omega_12, omega_12, omega_22]
                    else:
                        continue

                    p_ABC = get_ABC_inf(trans_mat_ABC, times_ABC, omegas_ABC, coal_ABC)
                    # r < R: mass split equally over labels 1-3 and recorded
                    # in both directions; r == R: label 4, both directions.
                    # The meaning of labels 1-4 is not defined in this cell.
                    if r < R:
                        acc3 += (pi@p_ABC).sum()*2
                        [tab.append([(0, l, L), (i, r, R), (pi@p_ABC).sum()/3]) for i in range(1, 4)]
                        [tab.append([(i, r, R), (0, l, L), (pi@p_ABC).sum()/3]) for i in range(1, 4)]
                    elif r == R:
                        acc3 += (pi@p_ABC).sum()
                        tab.append([(0, l, L), (4, r, R), (pi@p_ABC).sum()])
                        tab.append([(4, r, R), (0, l, L), (pi@p_ABC).sum()])
    
    
    
    
    # print((n_int_AB*n_int_ABC) * (3*scipy.special.comb(n_int_ABC, 2, exact = True)+n_int_ABC)*2)
    # print(len(tab) - (n_int_AB*n_int_ABC)**2)
    
    #print(acc3)
    #print(acc2)
    print('----')
    print(sum([i[2] for i in tab]))
    print('----')
    
    
    acc4 = 0                    
    # deep -> deep
    # AB row labelled ('D', 'D'): neither side coalesced in AB. l / L are the
    # left side's first / second coalescence intervals, r / R the right's.
    # Only orderings with L <= R are enumerated; the others are presumably
    # covered by the mirrored appends further down — confirm.
    cond = [i == ('D','D') for i in names_tab_AB]
    pi = pi_ABC[cond]
    pi = pi/pi.sum()
    for l in range(n_int_ABC):
        for L in range(l, n_int_ABC):
            for r in range(n_int_ABC):
                for R in range(r, n_int_ABC): 
                    if l < L < r < R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, L, L+1, r, r+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_10, omega_10, omega_20, omega_20, omega_21, omega_21, omega_22]
                    elif l < L == r < R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_10, omega_10, omega_21, omega_21, omega_22]
                    elif l == r < L < R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_11, omega_11, omega_21, omega_21, omega_22]
                    elif l < r < L < R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, r, r+1, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_10, omega_10, omega_11, omega_11, omega_21, omega_21, omega_22]
                    elif r < l < L < R:
                        times_ABC = get_times(cut_ABC, [0, r, r+1, l, l+1, L, L+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_01, omega_01, omega_11, omega_11, omega_21, omega_21, omega_22]
                    elif l == r < L == R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_11, omega_11, omega_22]
                    elif l < r < L == R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, r, r+1, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_10, omega_10, omega_11, omega_11, omega_22]
                    elif l == r == L == R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_22]
                    elif l == L < r == R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, r, r+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_20, omega_20, omega_22]
                    elif l == L < r < R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, r, r+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_20, omega_20, omega_21, omega_21, omega_22]
                    elif l == L == r < R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_21, omega_21, omega_22]
                    elif l < L == r == R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, L, L+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_10, omega_10, omega_22]
                    elif l < L < r == R:
                        times_ABC = get_times(cut_ABC, [0, l, l+1, L, L+1, r, r+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_10, omega_10, omega_20, omega_20, omega_22]
                    elif r < l == L < R:
                        times_ABC = get_times(cut_ABC, [0, r, r+1, l, l+1, R, R+1])
                        omegas_ABC = [omega_tot_ABC, omega_00, omega_01, omega_01, omega_21, omega_21, omega_22]
                    else:
                        continue
                    
                    p_ABC = get_ABC_inf(trans_mat_ABC, times_ABC, omegas_ABC, coal_ABC)
                    # Label 4 is used when a side's two events share an
                    # interval (l == L or r == R); otherwise the mass is split
                    # equally over labels 1-3 (/3 per side, /9 when both).
                    # Asymmetric cases are recorded in both directions.
                    if l == r == L == R:
                        acc4 += (pi@p_ABC).sum()
                        tab.append([(4, l, L), (4, r, R), (pi@p_ABC).sum()])
                    elif l == r < L == R:
                        acc4 += (pi@p_ABC).sum()
                        [[tab.append([(i, l, L), (j, r, R), (pi@p_ABC).sum()/9]) for j in range (1, 4)] for i in range(1, 4)]
                    elif l == L < r == R:
                        acc4 += (pi@p_ABC).sum()*2
                        tab.append([(4, l, L), (4, r, R), (pi@p_ABC).sum()])
                        tab.append([(4, r, R), (4, l, L), (pi@p_ABC).sum()])
                    elif l == L:
                        acc4 += (pi@p_ABC).sum()*2
                        [tab.append([(4, l, L), (i, r, R), (pi@p_ABC).sum()/3]) for i in range (1, 4)]
                        [tab.append([(i, r, R), (4, l, L), (pi@p_ABC).sum()/3]) for i in range (1, 4)]
                    elif r == R:
                        acc4 += (pi@p_ABC).sum()*2
                        [tab.append([(i, l, L), (4, r, R), (pi@p_ABC).sum()/3]) for i in range (1, 4)]
                        [tab.append([(4, r, R), (i, l, L), (pi@p_ABC).sum()/3]) for i in range (1, 4)]
                    else:
                        acc4 += (pi@p_ABC).sum()*2
                        [[tab.append([(i, l, L), (j, r, R), (pi@p_ABC).sum()/9]) for j in range (1, 4)] for i in range(1, 4)]
                        [[tab.append([(j, r, R), (i, l, L), (pi@p_ABC).sum()/9]) for j in range (1, 4)] for i in range(1, 4)]
    
    print('----')
    print(sum([i[2] for i in tab]))
    print('----')
    
    # print(n_markov_states**2)
                        
    # Each row: [left label, right label, probability].
    return pd.DataFrame(tab)

In [1372]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=1, n_int_ABC=1)

----
0.4322367818263902
0.43223678182639025
----


In [1373]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=1, n_int_ABC=2)

----
0.4322367818263902
0.5195738705021714
----


In [1374]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=1, n_int_ABC=3)

----
0.4322367818263902
0.5497039847725773
----


In [1375]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=2, n_int_ABC=1)

----
0.4322367818263902
0.4322367818263902
----


In [1376]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=2, n_int_ABC=2)

----
0.4322367818263902
0.5195738705021713
----


In [1377]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=2, n_int_ABC=3)

----
0.4322367818263902
0.5497039847725771
----


The summed probability (the second printed value) is inflated as n_int_ABC increases (≈0.432 → 0.520 → 0.550 for n_int_ABC = 1, 2, 3), but it is unaffected by n_int_AB.

In [1265]:
# Scratch check: hand-copied probabilities (the middle term counted twice); the sum is ~1/3.
0.16666666666666666+0.033676406215262945*2+0.09931385423614086

0.3333333333333334

In [1227]:
# Scratch check: these three hand-copied probabilities sum to ~1.
0.2979415627084226+0.20205843729157769+0.5

1.0000000000000002

In [1196]:
# NOTE: get_HMM_trans_mat (and load_trans_mat) are defined in later cells of
# this notebook; run those first when starting from a fresh kernel.
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=2, n_int_ABC=3)

(0, 0, 0) (0, 0, 0)
0.9999999999999996
0.1440789272754634
(0, 0, 0) (0, 0, 1)
0.9999999999999996
0.03903506444584945
(0, 0, 1) (0, 0, 0)
0.9999999999999996
0.03903506444584945
(0, 0, 0) (0, 0, 2)
0.9999999999999996
0.039035064445849435
(0, 0, 2) (0, 0, 0)
0.9999999999999996
0.039035064445849435
(0, 0, 1) (0, 0, 1)
0.9999999999999996
0.2062281448588368
(0, 0, 1) (0, 0, 2)
0.9999999999999996
0.07625626294350059
(0, 0, 2) (0, 0, 1)
0.9999999999999996
0.07625626294350059
(0, 0, 2) (0, 0, 2)
0.9999999999999996
0.18207490041952837
(0, 0, 0) (0, 1, 0)
0.9999999999999996
0.1440789272754634
(0, 0, 0) (0, 1, 1)
0.9999999999999996
0.03903506444584945
(0, 1, 1) (0, 0, 0)
0.9999999999999996
0.03903506444584945
(0, 0, 0) (0, 1, 2)
0.9999999999999996
0.039035064445849435
(0, 1, 2) (0, 0, 0)
0.9999999999999996
0.039035064445849435
(0, 0, 1) (0, 1, 1)
0.9999999999999996
0.2062281448588368
(0, 0, 1) (0, 1, 2)
0.9999999999999996
0.07625626294350059
(0, 1, 2) (0, 0, 1)
0.9999999999999996
0.076256262943500

In [851]:
def load_trans_mat(n_seq, data_dir='../02_state_space/trans_mats'):
    """
    Load the symbolic transition matrix of the ``n_seq``-sequence CTMC.

    Reads ``{data_dir}/trans_mat_simple_{n_seq}.csv``, which must have the
    columns ``from`` and ``to`` (state ids), ``value`` (rate label such as
    '1C' or '2R') and ``from_str`` / ``to_str`` (state labels).

    Parameters
    ----------
    n_seq : int
        Number of sequences; selects the CSV file.
    data_dir : str, optional
        Directory containing the CSV files. Defaults to the path this
        notebook has always used.

    Returns
    -------
    trans_mat : numpy.ndarray
        Square object array of rate labels, rows/columns ordered by state id,
        with '0' where there is no transition.
    state_names : list of str
        State labels ordered by state id.
    """
    df = pd.read_csv(f'{data_dir}/trans_mat_simple_{n_seq}.csv')
    labels = (
        pd.DataFrame({'names': pd.concat([df['from_str'], df['to_str']]),
                      'values': pd.concat([df['from'], df['to']])})
        .drop_duplicates()
        .sort_values(by=['values'])
    )
    # Reindex on every state id seen as a source or a target, so the matrix
    # stays square even when a state has no outgoing transitions (pivot
    # alone would silently drop that row).
    state_ids = np.union1d(df['from'], df['to'])
    trans_mat = (
        df.pivot(index='from', columns='to', values='value')
        .reindex(index=state_ids, columns=state_ids)
        .fillna('0')
    )
    return trans_mat.to_numpy(), list(labels['names'])


In [807]:
def get_HMM_trans_mat(t_A,    t_B,    t_AB,    t_C, 
                      rho_A,  rho_B,  rho_AB,  rho_C,  rho_ABC, 
                      coal_A, coal_B, coal_AB, coal_C, coal_ABC,
                      n_int_AB, n_int_ABC):
    """
    Run the full pipeline: one-sequence CTMCs for A, B and C, then the
    two-sequence AB CTMC, then the three-sequence ABC CTMC.

    Parameters
    ----------
    t_A, t_B, t_C : float
        Durations over which the A, B and C one-sequence chains are run.
    t_AB : float
        Upper truncation point of the AB cut points.
    rho_A, rho_B, rho_AB, rho_C, rho_ABC : float
        Recombination rates of each chain.
    coal_A, coal_B, coal_AB, coal_C, coal_ABC : float
        Coalescent rates of each chain (coal_AB / coal_ABC also shape the
        AB / ABC cut points).
    n_int_AB, n_int_ABC : int
        Number of intervals in the AB and ABC epochs.

    Returns
    -------
    The value returned by get_tab_ABC.
    """
    def _load_states(n_seq):
        # Symbolic matrix, raw state labels, and parsed states.
        symbolic, labels = load_trans_mat(n_seq)
        return symbolic, labels, [ast.literal_eval(label) for label in labels]

    def _align(names, values, labels):
        # Map combined-state probabilities onto the CSV state order;
        # states absent from `names` get probability 0.
        lookup = {}
        for name, value in zip(names, values):
            lookup.setdefault(name, value)
        return [lookup[label] if label in lookup else 0 for label in labels]

    # --- State-space information ---
    trans_mat_1, state_space_1, state_space_A = _load_states(1)
    trans_mat_2, state_space_2, state_space_AB = _load_states(2)
    trans_mat_3, state_space_3, state_space_ABC = _load_states(3)

    # --- One-sequence CTMCs (A, B, C share the one-sequence state space) ---
    trans_mat_A, trans_mat_B, trans_mat_C = [
        trans_mat_num(trans_mat_1, coal, rho)
        for coal, rho in ((coal_A, rho_A), (coal_B, rho_B), (coal_C, rho_C))
    ]
    # Row 0 of exp(Q t): state distribution after time t from the first state.
    final_A, final_B, final_C = [
        expm(rate_mat * duration)[0]
        for rate_mat, duration in ((trans_mat_A, t_A), (trans_mat_B, t_B), (trans_mat_C, t_C))
    ]

    names_AB, values_AB = combine_states(state_space_A, state_space_A, final_A, final_B)
    pi_AB = _align(names_AB, values_AB, state_space_2)

    # --- Two-sequence CTMC ---
    cut_AB = cutpoints_AB(n_int_AB, t_AB, coal_AB)
    trans_mat_AB = trans_mat_num(trans_mat_2, coal_AB, rho_AB)
    names_tab_AB, tab_AB = get_tab_AB(state_space_AB, trans_mat_AB, cut_AB, pi_AB)

    def _to_pi_ABC(row):
        # Combine one AB outcome row with the C distribution.
        names_ABC, values_ABC = combine_states(state_space_AB, state_space_A, row, final_C)
        return _align(names_ABC, values_ABC, state_space_3)

    pi_ABC = np.apply_along_axis(_to_pi_ABC, axis=1, arr=tab_AB)

    # --- Three-sequence CTMC ---
    cut_ABC = cutpoints_ABC(n_int_ABC, coal_ABC)
    trans_mat_ABC = trans_mat_num(trans_mat_3, coal_ABC, rho_ABC)
    return get_tab_ABC(state_space_ABC, trans_mat_ABC, coal_ABC, cut_ABC,
                       pi_ABC, names_tab_AB, n_int_AB)
    
    

In [808]:
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=1, n_int_ABC=1)

0.4322367818263902
0.43223678182639025
0.19988377700216708
0.19988377700216708


Unnamed: 0,0,1,2
0,"(0, 0, 0)","(0, 0, 0)",0.432237
1,"(0, 0, 0)","(4, 0, 0)",0.099942
2,"(4, 0, 0)","(0, 0, 0)",0.099942
3,"(4, 0, 0)","(4, 0, 0)",0.167996


In [809]:
get_HMM_trans_mat(t_A=0.1, t_B=0.1, t_AB=1, t_C=2,
                  rho_A=2, rho_B=1, rho_AB=3, rho_C=1, rho_ABC=1,
                  coal_A=1, coal_B=0.5, coal_AB=1, coal_C=1, coal_ABC=1,
                  n_int_AB=1, n_int_ABC=2)

0.4322367818263902
0.43223678182639036
0.17196868644758312
0.19988377700216708


Unnamed: 0,0,1,2
0,"(0, 0, 0)","(0, 0, 0)",0.216118
1,"(0, 0, 0)","(0, 0, 1)",0.043669
2,"(0, 0, 1)","(0, 0, 0)",0.043669
3,"(0, 0, 1)","(0, 0, 1)",0.128781
4,"(0, 0, 0)","(4, 0, 0)",0.024985
5,"(4, 0, 0)","(0, 0, 0)",0.024985
6,"(0, 0, 0)","(1, 0, 1)",0.017004
7,"(0, 0, 0)","(2, 0, 1)",0.017004
8,"(0, 0, 0)","(3, 0, 1)",0.017004
9,"(1, 0, 1)","(0, 0, 0)",0.017004


0.0780701288916989

In [465]:
# Scratch check: these terms appear to match the n_int_AB=1, n_int_ABC=2 table
# (0.216118, 2 x 0.043669, 0.128781) and reproduce the ~0.4322 total mass.
0.21611839091319512+0.08733708867578102+0.1287813022374142

0.43223678182639036

In [203]:
# Scratch check: four hand-copied terms whose sum reproduces the ~0.4322 total
# mass printed elsewhere in this notebook.
0.1690674674997399+0.06792790037960195+0.06792790037960195+0.1273135135674464

0.4322367818263902