In [582]:
from trails.load_trans_mat import load_trans_mat, trans_mat_num
from ast import literal_eval
from trails.combine_states import combine_states
from scipy.linalg import expm
import numpy as np
from trails.cutpoints import cutpoints_AB, cutpoints_ABC


First, we can define the demographic process:

In [583]:
coal = 1
rho = 0.0008
# Time from present to migration event
t_m = 0.4
# Time from migration event to first speciation event
t_1 = 0.2
# Time between speciation events
t_2 = 0.3
# Migration rate
m = 0.2

Now, we can calculate the transition rate matrices and the state space for each of the species:

In [584]:
# Load string transition rate matrix and convert string names to actual lists
(trans_mat_1, state_space_1) = load_trans_mat(1)
state_space_A = [literal_eval(i) for i in state_space_1]
(trans_mat_2, state_space_2) = load_trans_mat(2)
state_space_AB = [literal_eval(i) for i in state_space_2]
(trans_mat_3, state_space_3) = load_trans_mat(3)
state_space_ABC = [literal_eval(i) for i in state_space_3]

state_space_B = []
for j in state_space_A:
    lst = []
    for i in j:
        one = 2 if i[0] == 1 else i[0]
        two = 2 if i[1] == 1 else i[1]
        lst.append((one, two))
    state_space_B.append(lst)

state_space_C = []
for j in state_space_A:
    lst = []
    for i in j:
        one = 4 if i[0] == 1 else i[0]
        two = 4 if i[1] == 1 else i[1]
        lst.append((one, two))
    state_space_C.append(lst)

In [585]:
trans_mat_A = trans_mat_num(trans_mat_1, coal, rho)
trans_mat_B = trans_mat_num(trans_mat_1, coal, rho)
trans_mat_C = trans_mat_num(trans_mat_1, coal, rho)
trans_mat_AB = trans_mat_num(trans_mat_2, coal, rho)


We can now calculate the transition probability matrix for each species at the time of migration:

In [586]:
final_A = expm(trans_mat_A*(t_m+t_1))[0]
final_B = expm(trans_mat_B*t_m)[0]
final_C = expm(trans_mat_C*t_m)[0]
final_C_bis = expm(trans_mat_C*(t_m+t_1+t_2))[0]
final_A_bis = expm(trans_mat_A*(t_m+t_1+t_2))[0]

The following function splits the vector of final probabilities for species B given a certain migration rate `m`. Based on the direction of choice, the resulting probabilities will correspond to the left path (where lineages do not migrate and will later mix with species A), or the right path (where lineages migrate and are instantly mixed with species C):

In [587]:
def split_migration(state_space, prob_vec, m, direction):
    x = prob_vec[1]
    st = [state_space[0], state_space[1], [state_space[1][0]], [state_space[1][1]]]
    if direction == 'left':
        pr = np.array([(1-x)*(1-m), (1-m)**2*x, 1/2*(1-m)*m*x, 1/2*(1-m)*m*x])
    if direction == 'right':
        pr = np.array([(1-x)*m, x*m**2, 1/2*(1-m)*m*x, 1/2*(1-m)*m*x])
    return (st, pr)

# def split_migration(state_space, prob_vec, m, direction):
#     x = prob_vec[1]
#     st = [state_space[0], state_space[1], [state_space[1][0]], [state_space[1][1]]]
#     if direction == 'left':
#         pr = np.array([(1-x)*(1-m), (1-m)**2*x, (1-m)*m*x, (1-m)*m*x])
#     if direction == 'right':
#         pr = np.array([(1-x)*m, x*m**2, (1-m)*m*x, (1-m)*m*x])
#     return (st, pr)


In [588]:
(state_space_B_left, final_B_left) = split_migration(state_space_B, final_B, m, 'left')
(state_space_B_right, final_B_right) = split_migration(state_space_B, final_B, m, 'right')
sum(final_B_left)+sum(final_B_right)

1.0

## Right path, migration event to second speciation event 

### Full lineages

For the right path, we can combine the state space of the migrated B lineages with the state space of C:

In [589]:
(comb_BC_name_full, comb_BC_value_full) = combine_states(state_space_A, state_space_B_right[0:2], final_C, final_B_right[0:2])
# This is with the correct ordering
pi_BC_full = [comb_BC_value_full[comb_BC_name_full.index(i)] if i in comb_BC_name_full else 0 for i in state_space_2]
final_BC_full = pi_BC_full @ expm(trans_mat_AB*(t_1+t_2))
final_BC_full

array([6.20896446e-08, 2.12027948e-04, 1.72420469e-08, 1.72420469e-08,
       2.12027948e-04, 6.05916214e-01, 9.21199745e-09, 4.41928205e-08,
       4.41928205e-08, 9.51201559e-05, 9.51201559e-05, 9.51201559e-05,
       9.51201559e-05, 7.21362624e-05, 3.93206920e-01])

In [590]:
state_space_BC = [j.replace('1', '4') for j in state_space_2]
state_space_BC = [j.replace('3', '6') for j in state_space_BC]
state_space_BC = [literal_eval(i) for i in state_space_BC]
state_space_BC = [sorted(i) for i in state_space_BC]
state_space_BC

[[(0, 2), (0, 4), (2, 0), (4, 0)],
 [(0, 2), (2, 0), (4, 4)],
 [(0, 2), (2, 4), (4, 0)],
 [(0, 4), (2, 0), (4, 2)],
 [(0, 4), (2, 2), (4, 0)],
 [(2, 2), (4, 4)],
 [(2, 4), (4, 2)],
 [(0, 6), (2, 0), (4, 0)],
 [(0, 2), (0, 4), (6, 0)],
 [(2, 0), (4, 6)],
 [(0, 2), (6, 4)],
 [(2, 6), (4, 0)],
 [(0, 4), (6, 2)],
 [(0, 6), (6, 0)],
 [(6, 6)]]

### Missing lineages

In [591]:
def load_trans_mat_miss(i):
    if i == 1:
        mat = np.array([['0', '0', 'C', '0'],
                     ['0', '0', '0', 'C'],
                     ['R', '0', '0', '0'],
                     ['0', 'R', '0', '0']], dtype=object)
        st = ['[(0, 1), (0, 2), (1, 0)]', '[(0, 1), (1, 0), (2, 0)]', '[(0, 2), (1, 1)]', '[(1, 1), (2, 0)]']
    if i == 2:
        mat = np.array(
            [
                ['0', 'R', '0', '0', 'C', '0', '0', '0', '0', '0'],
                ['C', '0', 'C', 'C', '0', '0', '0', '0', '0', '0'],
                ['0', 'R', '0', '0', 'C', '0', '0', '0', '0', '0'],
                ['0', '0', '0', '0', 'C', '0', '0', '0', '0', '0'],
                ['0', '0', '0', 'R', '0', '0', '0', '0', '0', '0'],
                ['0', '0', '0', '0', '0', '0', 'R', '0', '0', 'C'],
                ['0', '0', '0', '0', '0', 'C', '0', 'C', 'C', '0'],
                ['0', '0', '0', '0', '0', '0', 'R', '0', '0', 'C'],
                ['0', '0', '0', '0', '0', '0', '0', '0', '0', 'C'],
                ['0', '0', '0', '0', '0', '0', '0', '0', 'R', '0']
            ], 
            dtype=object)
        st = [
            '[(2, 0), (4, 4)]', 
            '[(0, 4), (2, 0), (4, 0)]', 
            '[(2, 4), (4, 0)]', 
            '[(0, 4), (6, 0)]',
            '[(6, 4)]',
            '[(0, 2), (4, 4)]', 
            '[(0, 2), (0, 4), (4, 0)]', 
            '[(0, 4), (4, 2)]', 
            '[(0, 6), (4, 0)]',
            '[(4, 6)]']
    return (mat, st)
(trans_mat_2_miss, state_space_2_miss) = load_trans_mat_miss(2)
state_space_BC_miss = [literal_eval(i) for i in state_space_2_miss]
trans_mat_BC_miss = trans_mat_num(trans_mat_2_miss, coal, rho)

In [592]:
(comb_BC_name_miss, comb_BC_value_miss) = combine_states(state_space_C, state_space_B_right[2::], final_C, final_B_right[2::])
# This is with the correct ordering
pi_BC_miss = [comb_BC_value_miss[comb_BC_name_miss.index(i)] if i in comb_BC_name_miss else 0 for i in state_space_2_miss]
final_BC_miss = pi_BC_miss @ expm(trans_mat_BC_miss*(t_1+t_2))
final_BC_miss


array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [593]:
sum(final_BC_full) + sum(final_BC_miss)

1.0000000000000004

## Left path, migration to second speciation event

### Full lineages

In [594]:
final_B_left_full = final_B_left[0:2] @ expm(trans_mat_B*t_1)
(comb_AB_name_full, comb_AB_value_full) = combine_states(state_space_A, state_space_B_left[0:2], final_A, final_B_left_full)
# This is with the correct ordering
pi_AB_full = [comb_AB_value_full[comb_AB_name_full.index(i)] if i in comb_AB_name_full else 0 for i in state_space_2]

In [595]:
final_AB_full = pi_AB_full @ expm(trans_mat_AB*t_2)

### Missing lineages

In [596]:
(trans_mat_4, state_space_4) = load_trans_mat_miss(2)
trans_mat_AB_miss = trans_mat_num(trans_mat_4, coal, rho)
state_space_4 = [i.replace('4', '1') for i in state_space_4]
state_space_4 = [i.replace('6', '3') for i in state_space_4]
state_space_AB_miss = [sorted(literal_eval(i)) for i in state_space_4]
state_space_4 = [str(i) for i in state_space_AB_miss]
state_space_AB_miss

[[(1, 1), (2, 0)],
 [(0, 1), (1, 0), (2, 0)],
 [(1, 0), (2, 1)],
 [(0, 1), (3, 0)],
 [(3, 1)],
 [(0, 2), (1, 1)],
 [(0, 1), (0, 2), (1, 0)],
 [(0, 1), (1, 2)],
 [(0, 3), (1, 0)],
 [(1, 3)]]

In [597]:
(comb_AB_name_miss, comb_AB_value_miss) = combine_states(state_space_A, state_space_B_left[2::], final_A, final_B_left[2::])
pi_AB_miss = [comb_AB_value_miss[comb_AB_name_miss.index(i)] if i in comb_AB_name_miss else 0 for i in state_space_4]
final_AB_miss  = pi_AB_miss @ expm(trans_mat_AB_miss*(t_1))

In [598]:
sum(final_AB_full) + sum(final_AB_miss)

0.0

## Mixing probabilities for deep coalescence

In [599]:
sum(final_BC_miss)+sum(final_AB_miss)

0.0

In [600]:
lst_a = []
lst_b = []

In [601]:
(a, b) = combine_states(
    state_space_AB_miss[5::], state_space_BC_miss[0:5], 
    final_AB_miss[5::], final_BC_miss[0:5]/sum(final_BC_miss[0:5]))
lst_a = lst_a+a
lst_b = lst_b+b
a = [literal_eval(i) for i in a]
a = [sorted(i) for i in a]
sum(b)

  This is separate from the ipykernel package so we can avoid doing imports until


0.0

In [602]:
(a, b) = combine_states(
    state_space_AB_miss[5::], state_space_BC_miss[0:5], 
    final_AB_miss[5::]/sum(final_AB_miss[5::]), final_BC_miss[0:5])
lst_a = lst_a+a
lst_b = lst_b+b
a = [literal_eval(i) for i in a]
a = [sorted(i) for i in a]
sum(b)

  This is separate from the ipykernel package so we can avoid doing imports until


0.0

In [603]:
(a, b) = combine_states(
    state_space_AB_miss[0:5], state_space_BC_miss[5::], 
    final_AB_miss[0:5], final_BC_miss[5::]/sum(final_BC_miss[5::]))
lst_a = lst_a+a
lst_b = lst_b+b
a = [literal_eval(i) for i in a]
a = [sorted(i) for i in a]
sum(b)

  This is separate from the ipykernel package so we can avoid doing imports until


0.0

In [604]:
(a, b) = combine_states(
    state_space_AB_miss[0:5], state_space_BC_miss[5::], 
    final_AB_miss[0:5]/sum(final_AB_miss[0:5]), final_BC_miss[5::])
lst_a = lst_a+a
lst_b = lst_b+b
a = [literal_eval(i) for i in a]
a = [sorted(i) for i in a]
sum(b)

  This is separate from the ipykernel package so we can avoid doing imports until


0.0

In [605]:
(a, b) = combine_states(state_space_AB, state_space_C, final_AB_full, final_C_bis)
lst_a = lst_a+a
lst_b = lst_b+b
a = [literal_eval(i) for i in a]
a = [sorted(i) for i in a]
sum(b)

0.0

In [606]:
(a, b) = combine_states(state_space_BC, state_space_A, final_BC_full, final_A_bis)
lst_a = lst_a+a
lst_b = lst_b+b
a = [literal_eval(i) for i in a]
a = [sorted(i) for i in a]
sum(b)

1.0000000000000002

In [607]:
sum(lst_b)

1.0000000000000002

In [608]:
dct = {}
for i in range(len(lst_a)):
    if lst_a[i] not in dct:
        dct[lst_a[i]] = lst_b[i]
    else:
        dct[lst_a[i]] += lst_b[i]

In [609]:
sum(dct.values())

1.0000000000000002

In [610]:
import pandas as pd

In [611]:
dat = pd.DataFrame()
dat['names'] = dct.keys()
dat['values'] = dct.values()
dat.sort_values('values')

Unnamed: 0,names,values
34,"[(0, 2), (0, 4), (3, 1), (4, 0)]",0.000000
23,"[(1, 3), (2, 4), (4, 0)]",0.000000
24,"[(1, 3), (6, 4)]",0.000000
25,"[(0, 1), (0, 2), (0, 4), (3, 0), (4, 0)]",0.000000
26,"[(0, 1), (0, 2), (3, 0), (4, 4)]",0.000000
...,...,...
12,"[(0, 2), (1, 1), (2, 0), (4, 4)]",0.000212
53,"[(0, 4), (1, 1), (2, 2), (4, 0)]",0.000212
48,"[(0, 1), (1, 0), (2, 2), (4, 4)]",0.000288
69,"[(1, 1), (6, 6)]",0.393020


In [612]:
om = {}
flatten = [list(sum(i, ())) for i in state_space_ABC]    
for l in [0, 3, 5, 6, 7]:
    for r in [0, 3, 5, 6, 7]:
        if (l in [3, 5, 6, 7]) and (r in [3, 5, 6, 7]):
            om['%s%s' % (l, r)] = [i for i in range(203) if (l in flatten[i][::2]) and (r in flatten[i][1::2])]
        elif (l == 0) and (r in [3, 5, 6, 7]):
            om['%s%s' % (l, r)] = [i for i in range(203) if (all(x not in [3, 5, 6, 7] for x in flatten[i][::2])) and (r in flatten[i][1::2])]
        elif (l  in [3, 5, 6, 7]) and (r == 0):
            om['%s%s' % (l, r)] = [i for i in range(203) if (l in flatten[i][::2]) and (all(x not in [3, 5, 6, 7] for x in flatten[i][1::2]))]
        elif l == r == 0:
            om['%s%s' % (l, r)] = [i for i in range(203) if all(x not in [3, 5, 6, 7] for x in flatten[i])]
omega_tot_ABC = [i for i in range(203)]
om['71'] = sorted(om['73']+om['75']+om['76'])
om['17'] = sorted(om['37']+om['57']+om['67'])
om['10'] = sorted(om['30']+om['50']+om['60'])
om['13'] = sorted(om['33']+om['53']+om['63'])
om['15'] = sorted(om['35']+om['55']+om['65'])
om['16'] = sorted(om['36']+om['56']+om['66'])
om['11'] = sorted(om['13']+om['15']+om['16'])

dct_num = {3:1, 5:2, 6:3}

In [613]:
[state_space_ABC[i] for i in om['06']]

[[(0, 1), (0, 6), (1, 0), (2, 0), (4, 0)],
 [(0, 6), (1, 1), (2, 0), (4, 0)],
 [(0, 6), (1, 0), (2, 1), (4, 0)],
 [(0, 6), (1, 0), (2, 0), (4, 1)],
 [(0, 1), (1, 6), (2, 0), (4, 0)],
 [(0, 1), (1, 0), (2, 6), (4, 0)],
 [(0, 1), (1, 0), (2, 0), (4, 6)],
 [(1, 1), (2, 6), (4, 0)],
 [(1, 1), (2, 0), (4, 6)],
 [(1, 6), (2, 1), (4, 0)],
 [(1, 0), (2, 1), (4, 6)],
 [(1, 6), (2, 0), (4, 1)],
 [(1, 0), (2, 6), (4, 1)]]

In [614]:
ordered_pi_ABC = [list(dct.values())[list(dct.keys()).index(str(i))] if str(i) in list(dct.keys()) else 0 for i in state_space_ABC]

In [615]:
sum(np.array(ordered_pi_ABC)[om['36']])

0.0

In [616]:
sum(np.array(ordered_pi_ABC)[om['63']])

0.0

In [617]:
sum(np.array(ordered_pi_ABC)[om['66']])

0.3932790557827152

In [618]:
sum(np.array(ordered_pi_ABC)[om['33']])

0.0

In [619]:
sum(list(dct.values()))

1.0000000000000002

# Getting the joint probability table

In [620]:
n_int_AB = 1
n_int_ABC = 1

# Create empty table for the joint probabilities
tab = np.zeros((9, 203))
# Create empty vector for the names of the states
tab_names = []
# Create accumulator for keeping track of the indices for the table
acc = 0

tab_names.append(('D', 'D')) 
tmp_lst = om['00']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(((4, 0), (4, 0))) 
tmp_lst = om['66']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(((0, 0), (0, 0))) 
tmp_lst = om['33']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(('D', (4, 0)))
tmp_lst = om['06']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(((4, 0), 'D'))
tmp_lst = om['60']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(('D', (0, 0)))
tmp_lst = om['03']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(((0, 0), 'D'))
tmp_lst = om['30']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(((4, 0), (0, 0)))
tmp_lst = om['63']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1

tab_names.append(((0, 0), (4, 0)))
tmp_lst = om['36']
tab[acc] = [ordered_pi_ABC[i] if i in tmp_lst else 0 for i in range(len(state_space_ABC))]
acc += 1


In [621]:
tab.sum()

1.0

In [622]:
sum(ordered_pi_ABC)

1.0000000000000004

In [623]:
from function_file import get_tab_ABC_introgression
n_int_ABC = 1
n_int_AB = 1
trans_mat_ABC = trans_mat_num(trans_mat_3, coal, rho)
cut_ABC = cutpoints_ABC(n_int_ABC, coal)
cut_ABC

array([ 0., inf])

In [624]:
joint_mat = get_tab_ABC_introgression(state_space_ABC, trans_mat_ABC, cut_ABC, tab, tab_names, n_int_AB)

In [625]:
joint_mat[:,2].sum()

1.0

In [626]:
joint_mat

array([[(0, 0, 0), (0, 0, 0), 0.0],
       [(4, 0, 0), (4, 0, 0), 0.3932790557827152],
       [(0, 0, 0), (4, 0, 0), 0.0],
       [(4, 0, 0), (0, 0, 0), 0.0],
       [(0, 0, 0), (1, 0, 0), 0.0],
       [(1, 0, 0), (0, 0, 0), 0.0],
       [(4, 0, 0), (1, 0, 0), 6.342816821714624e-05],
       [(1, 0, 0), (4, 0, 0), 6.342816821714624e-05],
       [(0, 0, 0), (2, 0, 0), 0.0],
       [(2, 0, 0), (0, 0, 0), 0.0],
       [(4, 0, 0), (2, 0, 0), 6.342816821714624e-05],
       [(2, 0, 0), (4, 0, 0), 6.342816821714624e-05],
       [(0, 0, 0), (3, 0, 0), 0.0],
       [(3, 0, 0), (0, 0, 0), 0.0],
       [(4, 0, 0), (3, 0, 0), 6.342816821714624e-05],
       [(3, 0, 0), (4, 0, 0), 6.342816821714624e-05],
       [(1, 0, 0), (1, 0, 0), 0.20193482976181462],
       [(1, 0, 0), (2, 0, 0), 9.14179066453401e-05],
       [(1, 0, 0), (3, 0, 0), 8.721073420070985e-05],
       [(2, 0, 0), (1, 0, 0), 9.14179066453401e-05],
       [(2, 0, 0), (2, 0, 0), 0.20193482976181462],
       [(2, 0, 0), (3, 0, 0), 8.72107

In [627]:
tr = pd.DataFrame(joint_mat, columns=['From', 'To', 'Prob']).pivot(index = ['From'], columns = ['To'], values = ['Prob'])
tr.columns = tr.columns.droplevel()
hidden_names = list(tr.columns)
hidden_names = dict(zip(range(len(hidden_names)), hidden_names))
arr = np.array(tr).astype(np.float64)
pi = arr.sum(axis=1)
a = arr/pi[:,None]

pd.DataFrame(joint_mat, columns=['From', 'To', 'Prob']).to_csv(f'joint_mat_{m}.csv', index = False)

  import sys


In [628]:
a

array([[           nan,            nan,            nan,            nan,
                   nan],
       [0.00000000e+00, 9.98802747e-01, 4.52167942e-04, 4.31358578e-04,
        3.13726110e-04],
       [0.00000000e+00, 4.52167942e-04, 9.98802747e-01, 4.31358578e-04,
        3.13726110e-04],
       [0.00000000e+00, 4.31358578e-04, 4.31358578e-04, 9.98823557e-01,
        3.13726110e-04],
       [0.00000000e+00, 1.61202314e-04, 1.61202314e-04, 1.61202314e-04,
        9.99516393e-01]])