$
\begin{align}
&P(L=l, R=r, V=v, H=h)= \\
&=\begin{cases}
       \sum_{i\in \text{no coal}} & \text{if $l=r$ and $v=h$}\\
    \end{cases}       
\end{align}
$

In [155]:
from IPython.display import SVG
import numpy as np
import pandas as pd
from scipy.linalg import expm
%load_ext rpy2.ipython

The rpy2.ipython extension is already loaded. To reload it, use:
  %reload_ext rpy2.ipython


In [17]:
def recombination(i):
    """
    This function returns all possible states after a 
    single recombination event given an initial state
    """
    # Create empty list
    new_lst = []
    # For each block index in the list
    for block in range(len(i)):
        # If the site is linked
        if 0 not in i[block]:
            # Remove block
            lst = i[:block]+i[block+1:]
            # Split block
            lst = lst + [(i[block][0], 0), (0, i[block][1])]
            # Append to list
            new_lst.append(sorted(lst))
    return new_lst

In [18]:
def coalescent(i):
    """
    This function returns all possible states after a 
    single coalescence event given an initial state. 
    The returned list is somewhat ordered, so that the 
    first states correspond to reversible coalescence 
    events, and the last to non-reversible ones.
    """
    # Get unique values per site
    n_0 = set([0]+list(sum(i, ())))
    reversible = []
    nonreversible = []
    # For each combination of blocks
    for b1 in range(len(i)):
        for b2 in range(b1+1, len(i)):
            # Create new block by merging
            add_lst = [(i[b1][0]+i[b2][0], i[b1][1]+i[b2][1])]
            # For each index in i
            for b3 in range(len(i)):
                # If the index is not in among the merged blocks
                if b3 not in [b1, b2]:
                    # Append to list
                    add_lst.append(i[b3])
            # If the unique values are different from the 
            # original block
            if set([0]+list(sum(add_lst, ()))) == n_0:
                # Append to reversible coalescence list
                reversible.append(sorted(add_lst))
            else:
                # Append to non-reversible coalescence list
                nonreversible.append(sorted(add_lst))
    l = reversible+nonreversible
    return l

In [19]:

def idx(i):
    j = list(sum(i, ()))
    return -(len(j)-j.count(0))

def get_states(lst):
    """
    Given a list of states, this function finds all possible
    recombination and coalescence states iteratively.
    """
    # Sort block per states just in case
    all_states = [sorted(i) for i in lst]
    # Define pool of states to be transformed
    state_pool = all_states
    # While the pool is non-empty
    while len(state_pool) > 0:
        # Define result list per iteration
        iter_states = []
        # For each of the states in the pool
        for i in state_pool:
            # Save coalescent transformation
            iter_states = iter_states+coalescent(i)
            # Save recombination transformation
            iter_states = iter_states+recombination(i)
        # Keep unique values in the original order
        iter_states = sorted(list(list(i) for i in set(tuple(row) for row in iter_states)), key = iter_states.index)
        # Define pool for next round
        state_pool = [i for i in iter_states if i not in all_states]
        # Add pool to overall list
        all_states += state_pool
        
    all_states = sorted(all_states, key = idx)
    return all_states

import multiprocessing as mp
def get_states_parallel(lst):
    """
    Given a list of states, this function finds all possible
    recombination and coalescence states iteratively and in
    parallel.
    """
    all_states = [sorted(i) for i in lst]
    state_pool = all_states
    while len(state_pool) > 0:
        print(len(all_states), end = '\r')
        iter_states = []
        pool = mp.Pool(mp.cpu_count())
        iter_states = iter_states+[item for sublist in pool.map(coalescent, state_pool) for item in sublist]
        iter_states = iter_states+[item for sublist in pool.map(recombination, state_pool) for item in sublist]
        pool.close()
        iter_states = sorted(list(list(i) for i in set(tuple(row) for row in iter_states)), key = iter_states.index)
        state_pool = [i for i in iter_states if i not in all_states]
        all_states += state_pool
    all_states = sorted(all_states, key = idx)
    return all_states


In [20]:
def colored(r, g, b, text):
    """
    Given some color coordinates in RGB and 
    some text, this function returns the text in
    the right format for printing in color. 
    """
    return "\033[38;2;{};{};{}m{}\033[38;2;000;000;000m".format(r, g, b, text)

dct = {
    1:('204', '000', '000'),
    2:('000', '000', '204'),
    4:('204', '204', '000'),
    3:('204', '000', '204'),
    5:('204', '102', '000'),
    6:('000', '204', '000'),
    7:('000', '000', '000'),
}

for i in dct:
    entry = dct[i]
    print('-'+colored(entry[0], entry[1], entry[2], i))


-[38;2;204;000;000m1[38;2;000;000;000m
-[38;2;000;000;204m2[38;2;000;000;000m
-[38;2;204;204;000m4[38;2;000;000;000m
-[38;2;204;000;204m3[38;2;000;000;000m
-[38;2;204;102;000m5[38;2;000;000;000m
-[38;2;000;204;000m6[38;2;000;000;000m
-[38;2;000;000;000m7[38;2;000;000;000m


In [21]:
def print_states(i, fill):
    """
    This function transforms a numerical representation
    of an entry of the state space into a string with 
    coloring that can be used to print the entry in a 
    ball-and-stick representation. This function only 
    works for a population size of up to 3 individuals.
    The argument fill is used to specify the height of 
    the print. If fill exceeds the natural height of the
    printed state, then trailing whitespaces will be 
    printed.
    """
    string = ' ______ \n'
    for j in i:
        string += '| '
        if j[0] == 0:
            string += '  '
        else:
            entry = dct[j[0]]
            if j[0] in [1, 2, 4]:
                symbol = '●'
            else:
                symbol = 'X'
            string += (colored(entry[0], entry[1], entry[2], symbol)+'-')
        if j[1] == 0:
            string += '  '
        else:
            if j[1] in [1, 2, 4]:
                symbol = '●'
            else:
                symbol = 'X'
            entry = dct[j[1]]
            string += ('-'+colored(entry[0], entry[1], entry[2], symbol))
        string += ' |\n'
    string += ' ‾‾‾‾‾‾ '
    string += '\n        '*(fill-len(string.split('\n')))
    return string

print(print_states([(0, 1), (0, 2), (0, 4), (1, 0), (6, 0)], 6))

 ______ 
|   -[38;2;204;000;000m●[38;2;000;000;000m |
|   -[38;2;000;000;204m●[38;2;000;000;000m |
|   -[38;2;204;204;000m●[38;2;000;000;000m |
| [38;2;204;000;000m●[38;2;000;000;000m-   |
| [38;2;000;204;000mX[38;2;000;000;000m-   |
 ‾‾‾‾‾‾ 


In [22]:
def print_all_states(iter_lst, n_col):
    """
    This function re-formats the output of print_states
    for plotting several states in the same row/column. The 
    arguments are a list of states and the number of
    columns for printing. The returned object is a list 
    containing the string for each of the rows.  
    """
    # Define empty string holder list
    new_string = []
    # For each row
    for row in range(0, len(iter_lst), n_col):
        
        # Define index of the final state of that row
        end = row+n_col
        if end > len(iter_lst):
            end = len(iter_lst)
        
        # Save the height of the current row
        max_len = max([len(i) for i in iter_lst[row:end]])+2
        
        # Print states in right format with trailing whitespaces
        x = []
        for i in iter_lst[row:end]:
            x.append(print_states(i, max_len))
        
        # Define new empty string
        string = ''
        # For each row within each printed state
        for max_depth in range(max_len):
            # For each printed state
            for i in x:
                # Add row
                string += i.split('\n')[max_depth]
            string += '\n'
        # Append new string to final result
        new_string.append(string[:-2])
    
    return new_string


In [23]:
def iter_lst_to_ggplot(iter_lst_2):
    dat = pd.DataFrame(columns = ['id', 'seg_y', 'seg_xmin', 'seg_xmax', 'dot_color_left', 'dot_color_right'])
    for i in range(len(iter_lst_2)):
        for j in range(len(iter_lst_2[i])):
            if iter_lst_2[i][j][0] == 0:
                xmin = 0
                xmax = 1
                left = '-'
                right = str(iter_lst_2[i][j][1])
            elif iter_lst_2[i][j][1] == 0:
                xmin = -1
                xmax = 0
                left = str(iter_lst_2[i][j][0])
                right = '-'
            else:
                xmin = -1
                xmax = 1
                left = str(iter_lst_2[i][j][0])
                right = str(iter_lst_2[i][j][1])
            dat.loc[len(dat)] = [i, j, xmin, xmax, left, right]
    return dat

In [24]:
state_3 = [[(0, 1), (0, 2), (0, 4), (1, 0), (2, 0), (4, 0)]]

iter_lst_3 = get_states(state_3)
len(iter_lst_3)

203

In [25]:
trans_mat_3 = np.full((len(iter_lst_3), len(iter_lst_3)), '0')
for i in range(len(iter_lst_3)):
    new_lst = recombination(iter_lst_3[i])
    idx_lst = [iter_lst_3.index(j) for j in new_lst]
    for j in idx_lst:
        trans_mat_3[i,j]='R'
    new_lst = coalescent(iter_lst_3[i])
    idx_lst = [iter_lst_3.index(j) for j in new_lst]
    for j in idx_lst:
        trans_mat_3[i,j]='C'

In [26]:
trans_mat_3

array([['0', 'C', 'C', ..., '0', '0', '0'],
       ['R', '0', '0', ..., '0', '0', '0'],
       ['R', '0', '0', ..., '0', '0', '0'],
       ...,
       ['0', '0', '0', ..., '0', '0', 'C'],
       ['0', '0', '0', ..., '0', '0', 'C'],
       ['0', '0', '0', ..., '0', 'R', '0']], dtype='<U1')

In [29]:
%%R -i trans_mat_3 

library(tidyverse)

fun_trans <- function(x, coa, rho) {
    if (x == 'R') {
        return(rho)
    } else if (x == 'C') {
        return(coa)
    } else return(0)
}

set_coa_rho <- function(X, coa, rho) {
    
    dat = data.frame(lapply(X, Vectorize(fun_trans), coa = coa, rho = rho))
    dat = as.matrix(dat)
    dat <- matrix(as.numeric(dat),    # Convert to numeric matrix
                  ncol = ncol(X))
    dat[row(dat)==col(dat)] = -rowSums(dat)
    dat                                   
}

a <- set_coa_rho(trans_mat_3, 1, 2)




In [30]:
%%R

a[1:10, 1:10]

      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
 [1,]  -15    1    1    1    1    1    1    1    1     1
 [2,]    2  -12    0    0    0    0    0    0    0     0
 [3,]    2    0  -12    0    0    0    0    0    0     0
 [4,]    2    0    0  -12    0    0    0    0    0     0
 [5,]    2    0    0    0  -12    0    0    0    0     0
 [6,]    2    0    0    0    0  -12    0    0    0     0
 [7,]    2    0    0    0    0    0  -12    0    0     0
 [8,]    2    0    0    0    0    0    0  -12    0     0
 [9,]    2    0    0    0    0    0    0    0  -12     0
[10,]    2    0    0    0    0    0    0    0    0   -12


In [80]:
iter_lst_3[200]

[(0, 3), (7, 4)]

In [34]:
cutpoints_3_2 = []
for i in range(1, len(iter_lst_3)):
    if idx(iter_lst_3[i]) > idx(iter_lst_3[i-1]):
        cutpoints_3_2.append(i)
cutpoints_3_2

[34, 112, 183, 201]

-5

In [120]:
def get_indices(iter_lst_3, cutpoints_3_2, l, r, v, h):
    
    dct = {0:3, 1:3, 2:5, 3:6}
    
    if 0 not in [v, h]:
    
        # Starting states
        a = [[j for j in iter_lst_3[i] if 4 in j and 0 not in j and j != (4, 4)] for i in range(0, cutpoints_3_2[0])] 
        a = [i for i in range(len(a)) if len(a[i]) == 0]

        # No coalescence
        b = list(range(cutpoints_3_2[0]))

        if l != r:

            # 1st coalescent on min
            min_value = min([l, r])
            min_index = [l, r].index(min_value)
            c = [i for i in range(cutpoints_3_2[0], cutpoints_3_2[1]) if dct[[v, h][min_index]] in [j[min_index] for j in iter_lst_3[i]]]


            # No coalescence
            d = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[2]) if 7 in [j[min_index] for j in iter_lst_3[i]]]
            d = c+d

            # 1st coalescent on max
            max_value = max([l, r])
            max_index = [l, r].index(max_value)
            e_min = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[2]) if dct[[v, h][min_index]] in [j[min_index] for j in iter_lst_3[i]]]
            e_7_min = [i for i in range(cutpoints_3_2[2], cutpoints_3_2[3]) if 7 in [j[min_index] for j in iter_lst_3[i]]]
            e = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[3]) if dct[[v, h][max_index]] in [j[max_index] for j in iter_lst_3[i]]]    
            seen = set()
            e = [x for x in e_min+e_7_min+e if x in seen or seen.add(x)]
        
        else:
            
            # Both coalesce
            c_l = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[2]) if dct[v] in [j[0] for j in iter_lst_3[i]]]
            c_r = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[2]) if dct[h] in [j[1] for j in iter_lst_3[i]]]
            seen = set()
            c = [x for x in c_l+c_r if x in seen or seen.add(x)]
            
            d = []
            e = []
    # else:
    #     
    #     if v == h:
    #         a_l = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[2]) if 3 in [j[0] for j in iter_lst_3[i]]]
    #         a_r = [i for i in range(cutpoints_3_2[1], cutpoints_3_2[2]) if 3 in [j[1] for j in iter_lst_3[i]]]
    #         seen = set()
    #         a = [x for x in a_l+a_r if x in seen or seen.add(x)]
    #         
    #         b = []
    #         c = []
    #         d = []
    #         e = []
    #         
    #     else:
    #         if 
        
     
    return (a, b, c, d, e)

In [152]:
def calc_tm(nInt):
    return [0]+[-np.log(1-i/nInt) for i in range(1, nInt+1)]

In [153]:
a = calc_tm(4)
a

  return [0]+[-np.log(1-i/nInt) for i in range(1, nInt+1)]


[0, 0.2876820724517809, 0.6931471805599453, 1.3862943611198906, inf]

In [None]:
expm(tm0[j]*rate_mat)[1,1:3]%*%
expm((tm[j]-tm0[j])*rate_mat)[1:3,4]*
exp(-(tm0[k]-tm[j]))*
# which is the same as 
# expm((tm0[k]-tm[j])*rate_mat)[4,4]*0.5*
(1-exp(-(tm[k]-tm0[k])))*0.5

In [168]:
def set_coa_rho(iter_lst_3, coa, rho):
    trans_mat_3 = np.full((len(iter_lst_3), len(iter_lst_3)), 0)
    for i in range(len(iter_lst_3)):
        new_lst = recombination(iter_lst_3[i])
        idx_lst = [iter_lst_3.index(j) for j in new_lst]
        for j in idx_lst:
            trans_mat_3[i,j]=rho
        new_lst = coalescent(iter_lst_3[i])
        idx_lst = [iter_lst_3.index(j) for j in new_lst]
        for j in idx_lst:
            trans_mat_3[i,j]=coa
    return trans_mat_3


rate_mat = set_coa_rho(iter_lst_3, 1, 2)

rate_mat

array([[0, 1, 1, ..., 0, 0, 0],
       [2, 0, 0, ..., 0, 0, 0],
       [2, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 1],
       [0, 0, 0, ..., 0, 0, 1],
       [0, 0, 0, ..., 0, 2, 0]])

In [189]:
tm

[0, 0.6931471805599453, inf]

In [216]:
nInt = 3
tm = calc_tm(nInt)

rate_mat = set_coa_rho(iter_lst_3, 1, 3)

trans_mat = np.full((nInt*3, nInt*3), 0)
for l in range(1, nInt):
    for v in range(1, 4):
        for r in range(1, nInt):
            for h in range(1, 4):
                (a, b, c, d, e) = get_indices(iter_lst_3, cutpoints_3_2, l, r, v, h)
                if l < r:
                    mult_1 = expm(tm[l-1]*rate_mat)[np.ix_(a,b)]
                    mult_2 = expm((tm[l]-tm[l-1])*rate_mat)[np.ix_(b,c)]
                    mult_3 = expm((tm[r-1]-tm[l])*rate_mat)[np.ix_(c,d)]
                    mult_4 = expm((tm[r]-tm[r-1])*rate_mat)[np.ix_(d,e)]
                    trans_mat[(l-1)*3+(v-1), (r-1)*3+(h-1)] = np.sum(mult_1@mult_2@mult_3@mult_4)
                elif l > r:
                    mult_1 = expm(tm[r-1]*rate_mat)[np.ix_(a,b)]
                    mult_2 = expm((tm[r]-tm[r-1])*rate_mat)[np.ix_(b,c)]
                    mult_3 = expm((tm[l-1]-tm[r])*rate_mat)[np.ix_(c,d)]
                    mult_4 = expm((tm[l]-tm[l-1])*rate_mat)[np.ix_(d,e)]
                    trans_mat[(l-1)*3+(v-1), (r-1)*3+(h-1)] = np.sum(mult_1@mult_2@mult_3@mult_4)
                elif l == r:
                    mult_1 = expm(tm[l-1]*rate_mat)[np.ix_(a,b)]
                    mult_2 = expm((tm[l]-tm[l-1])*rate_mat)[np.ix_(b,c)]
                    trans_mat[(l-1)*3+(v-1), (r-1)*3+(h-1)] = np.sum(mult_1@mult_2)

  return [0]+[-np.log(1-i/nInt) for i in range(1, nInt+1)]


In [217]:
trans_mat

array([[  28,   18,   18, 1721, 1711, 1711,    0,    0,    0],
       [  18,   23,   23, 1721, 1729, 1729,    0,    0,    0],
       [  18,   23,   23, 1721, 1729, 1729,    0,    0,    0],
       [1721, 1721, 1721, 7371, 7011, 7011,    0,    0,    0],
       [1711, 1729, 1729, 7011, 7196, 7196,    0,    0,    0],
       [1711, 1729, 1729, 7011, 7196, 7196,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0]])