# Imports

In [1]:
import ipcoal
import toytree
import pandas as pd
import numpy as np
from ipcoal.smc.smc4 import get_embedded_gene_tree_table
from tqdm.notebook import tqdm
import toyplot
from itertools import combinations

# Functions

In [2]:
def p_ik(i, k, table):
    '''
    Coalescence term linking interval i to interval k of `table`.

    Reused by pb1, pb2 and prob_unchange_tree. Writing a = nedges, n = neff,
    sigma_{i+1} = stop of interval i, and T = dist:

      k == i:  -(1/a_i) * exp(-(a_i/n_i) * sigma_{i+1})
      k != i:  exp(-(a_i/n_i) * sigma_{i+1} - sum_{q=i+1}^{k-1} (a_q/n_q) * T_q)
               * (1/a_k) * (1 - exp(-(a_k/n_k) * T_k))

    Note the inner sum stops at q = k-1 (interval k enters only via the
    second factor).

    Parameters
    ----------
    i, k : int
        Positional row indices into `table`.
    table : pandas.DataFrame
        Interval table with at least `nedges`, `neff`, `stop` and `dist` columns.
    '''
    row_i = table.iloc[i]
    a_i = row_i.nedges
    n_i = row_i.neff
    sigma_next = row_i.stop

    if i == k:
        return (-1 / a_i) * np.exp((-a_i / n_i) * sigma_next)

    # Exponent of the "no coalescence before interval k" factor.
    passed = sum(
        (table.iloc[q].nedges / table.iloc[q].neff) * table.iloc[q].dist
        for q in range(i + 1, k)
    )
    survive = np.exp(-(a_i / n_i) * sigma_next - passed)

    # Probability mass of coalescing somewhere inside interval k.
    row_k = table.iloc[k]
    a_k = row_k.nedges
    absorb = (1 / a_k) * (1 - np.exp(-(a_k / row_k.neff) * row_k.dist))

    return survive * absorb


def pb1(i, table, m, I_b, I_bc):
    '''
    pb1(i) term from the derivation. get_unchange_prob sums it over the
    intervals i in range(m), i.e. those before the first interval shared
    with the sibling branch.

    With a_i = nedges, n_i = neff, interval bounds [start, stop] and
    T_i = stop - start, this returns

        (1/a_i) * ( T_i + n_i * (e^{a_i*stop/n_i} - e^{a_i*start/n_i})
                          * (sum_{k=i}^{I_bc-1} p_ik + sum_{k=m}^{I_b-1} p_ik) )

    Parameters
    ----------
    i : int
        Positional row of the interval in `table`.
    table : pandas.DataFrame
        Interval table (the caller passes the branch's intervals followed by
        its parent's).
    m : int
        Index of the first interval shared with the sibling branch.
    I_b : int
        Number of intervals belonging to the branch itself.
    I_bc : int
        Number of intervals in the branch plus its parent.
    '''
    row = table.iloc[i]
    a_i = row.nedges
    n_i = row.neff
    lower, upper = row.start, row.stop
    span = upper - lower

    through_parent = sum(p_ik(i, k, table) for k in range(i, I_bc))
    shared_part = sum(p_ik(i, k, table) for k in range(m, I_b))

    scale = (np.exp(a_i * upper / n_i) - np.exp(a_i * lower / n_i)) * n_i
    return (1 / a_i) * (span + scale * (through_parent + shared_part))


def pb2(i, table, I_b, I_bc):
    '''
    pb2(i) term from the derivation. get_unchange_prob sums it over the
    intervals i in range(m, I_b), i.e. those shared with the sibling branch.

    With a_i = nedges, n_i = neff, interval bounds [start, stop] and
    T_i = stop - start, this returns

        (1/a_i) * ( 2*T_i + n_i * (e^{a_i*stop/n_i} - e^{a_i*start/n_i})
                            * (2 * sum_{k=i}^{I_b-1} p_ik + sum_{k=I_b}^{I_bc-1} p_ik) )

    Parameters
    ----------
    i : int
        Positional row of the interval in `table`.
    table : pandas.DataFrame
        Interval table (the caller passes the branch's intervals followed by
        its parent's).
    I_b : int
        Number of intervals belonging to the branch itself.
    I_bc : int
        Number of intervals in the branch plus its parent.
    '''
    row = table.iloc[i]
    a_i = row.nedges
    n_i = row.neff
    lower, upper = row.start, row.stop
    span = upper - lower

    within_branch = sum(p_ik(i, k, table) for k in range(i, I_b))
    in_parent = sum(p_ik(i, k, table) for k in range(I_b, I_bc))

    scale = (np.exp(a_i * upper / n_i) - np.exp(a_i * lower / n_i)) * n_i
    return (1 / a_i) * (2 * span + scale * (2 * within_branch + in_parent))
    
    
def get_intervals(gnode, treetable, SPTREE):
    '''
    Return the rows of `treetable` that cover the gene-tree edge above `gnode`.

    The first row is found as follows:
    - For a leaf: the first row on the species-tree node that is the MRCA of
      its leaf names.
    - For an internal node: the first row whose `start` equals the `stop` of
      the row where `gnode` coalesces.

    The last row is the one where `gnode`'s parent coalesces. Rows are kept
    only if their `st_node` lies on the species-tree path between those two
    rows, and only within the label range [first row, last row].
    The result is re-indexed from 0.

    NOTE(review): relies on `treetable` being ordered so that this label
    slice spans the edge's intervals — confirm against the table builder.
    '''
    if gnode.is_leaf():
        start_sp_idx = SPTREE.get_mrca_node(*gnode.get_leaf_names()).idx
        first_rows = treetable.loc[treetable.st_node.eq(start_sp_idx)]
    else:
        coal_rows = treetable.loc[treetable.coal.eq(gnode.idx)]
        first_rows = treetable.loc[treetable.start.eq(coal_rows.stop.iloc[0])]

    last_rows = treetable.loc[treetable.coal.eq(gnode.up.idx)]
    last_sp_idx = last_rows.st_node.iloc[0]

    # Climb the species tree from the starting species node to the one in
    # which the parent coalescence happens.
    sp_node = SPTREE.get_nodes(int(first_rows.st_node.iloc[0]))[0]
    path = [sp_node.idx]
    while sp_node.idx != last_sp_idx:
        sp_node = sp_node.up
        path.append(sp_node.idx)

    on_path = treetable.loc[treetable['st_node'].isin(path)]
    on_path = on_path.loc[first_rows.index[0]:last_rows.index[0]]
    return on_path.reset_index(drop=True)

def _root_parent_intervals(root, gnode_ints):
    '''Build a one-row stand-in interval table for the (unbounded) root edge.'''
    # The root edge is given an effectively infinite length (1e9), placed in
    # the same species-tree node and with the same neff as the child's last interval.
    return pd.DataFrame(
        [
            root.height,
            root.height + 1e9,
            gnode_ints.iloc[-1].st_node,
            gnode_ints.iloc[-1].neff,
            1,
            np.nan,
            1e9,
        ],
        index=['start', 'stop', 'st_node', 'neff', 'nedges', 'coal', 'dist'],
    ).T


def get_unchange_prob(SPTREE, GTREE, IMAP):
    '''
    Probability that a recombination on GTREE leaves its rooted topology unchanged.

    For every non-root gene-tree edge, the per-edge probability is built from
    pb1 (intervals before sharing with the sibling) and pb2 (shared
    intervals). It is normalized by the edge length and weighted by the
    edge's share of the total tree length.

    Parameters
    ----------
    SPTREE : toytree
        Species tree (presumably with per-node "Ne" values — confirm).
    GTREE : toytree
        Gene tree (genealogy) embedded in SPTREE.
    IMAP : dict
        Mapping of species to sampled individuals, as returned by
        `ipcoal.Model.get_imap_dict()`.

    Returns
    -------
    float
    '''
    treetable = get_embedded_gene_tree_table(SPTREE, GTREE, IMAP)
    # NOTE(review): neff is doubled (also in prob_unchange_tree); presumably
    # converting Ne to a diploid 2N scale — confirm.
    treetable.neff = treetable.neff * 2

    # Total edge length; the last row of get_node_data() is excluded
    # (presumably the root, which has no parent edge — confirm).
    gtree_total_length = np.sum(GTREE.get_node_data().dist.iloc[:-1])

    total_prob = 0
    for gnode in GTREE.treenode.traverse(strategy='postorder'):
        if gnode.is_root():
            continue

        gnode_ints = get_intervals(gnode, treetable, SPTREE)

        parent = gnode.up
        if parent.is_root():
            parent_ints = _root_parent_intervals(parent, gnode_ints)
        else:
            parent_ints = get_intervals(parent, treetable, SPTREE)

        # Sibling's intervals that lie on species-tree nodes also visited by gnode.
        sib = next(child for child in parent.children if child is not gnode)
        sib_ints = get_intervals(sib, treetable, SPTREE)
        sib_shared = sib_ints.loc[sib_ints.st_node.isin(gnode_ints.st_node)]

        t_lb = gnode_ints.start.iloc[0]    # time the edge starts
        t_ub = gnode_ints.stop.iloc[-1]    # time the edge ends
        # Time sharing with the sibling starts; it cannot precede the edge itself.
        t_mb = max(sib_shared.iloc[0].start, t_lb)

        merged_ints = pd.concat([gnode_ints, parent_ints], ignore_index=True)

        # First merged interval that overlaps the shared period.
        m = merged_ints.loc[merged_ints.stop > t_mb].index[0]
        I_b = len(gnode_ints)
        I_bc = len(merged_ints)

        firstsum = sum(pb1(i, merged_ints, m, I_b, I_bc) for i in range(m))
        secsum = sum(pb2(i, merged_ints, I_b, I_bc) for i in range(m, I_b))

        normalize = 1 / (t_ub - t_lb)
        topo_unchanged_prob = normalize * (firstsum + secsum)

        total_prob += ((t_ub - t_lb) / gtree_total_length) * topo_unchanged_prob
    return total_prob

def prob_unchange_tree(gtre, sptre, IMAP):
    '''
    Probability that a recombination on `gtre` leaves the genealogy unchanged.

    For each non-root edge, this sums over the edge's intervals i:

        (1/a_i) * T_i + (n_i/a_i) * (e^{(a_i/n_i)*stop_i} - e^{(a_i/n_i)*start_i})
                                  * sum_{k=i}^{I_b-1} p_ik

    It then normalizes by the edge length and weights by the edge's share of
    the total tree length.

    Parameters
    ----------
    gtre : toytree
        Gene tree (genealogy).
    sptre : toytree
        Species tree in which `gtre` is embedded.
    IMAP : dict
        Mapping of species to sampled individuals, as returned by
        `ipcoal.Model.get_imap_dict()`.

    Returns
    -------
    float
    '''
    # Total edge length; the last row of get_node_data() is excluded
    # (presumably the root, which has no parent edge — confirm).
    gtree_total_length = np.sum(gtre.get_node_data().dist.iloc[:-1])

    treetable = get_embedded_gene_tree_table(sptre, gtre, IMAP)
    # NOTE(review): neff doubled as in get_unchange_prob; presumably a
    # diploid 2N scale — confirm.
    treetable.neff = treetable.neff * 2

    totaled_probs = 0
    for gnode in gtre.treenode.traverse(strategy='postorder'):
        if gnode.is_root():
            continue

        gnode_ints = get_intervals(gnode, treetable, sptre)
        t_ub = gnode_ints.stop.iloc[-1]   # time the edge ends
        t_lb = gnode_ints.start.iloc[0]   # time the edge starts
        I_b = len(gnode_ints)

        sumval = 0
        for i in range(I_b):
            row = gnode_ints.iloc[i]
            curr_ai = row.nedges
            curr_Ti = row.dist
            curr_ni = row.neff
            curr_alpha_i1 = row.stop
            curr_alpha_i = row.start

            first = (1 / curr_ai) * curr_Ti
            second = curr_ni / curr_ai
            third = (np.exp((curr_ai / curr_ni) * curr_alpha_i1)
                     - np.exp((curr_ai / curr_ni) * curr_alpha_i))
            fourth = sum(p_ik(i, k, gnode_ints) for k in range(i, I_b))
            sumval += first + second * third * fourth

        brprob = sumval * (1 / (t_ub - t_lb))
        totaled_probs += ((t_ub - t_lb) / gtree_total_length) * brprob
    return totaled_probs


def make_binary(ttre):
    '''
    Collapse every non-leaf node that has fewer than two children.

    Intended for toytrees containing unary nodes, e.g. parsed from a newick
    written by msprime with `record_full_arg=True`.

    Returns the new tree produced by `ttre.mod.collapse_nodes`.
    '''
    unary_idxs = [
        nd.idx
        for nd in ttre.traverse()
        if len(nd.children) < 2 and not nd.is_leaf()
    ]
    return ttre.mod.collapse_nodes(*unary_idxs)


def test_rooted_topos_equal(t1, t2):
    '''
    Test if the *rooted* topologies of two toytrees are equal.

    The trees must carry the same set of tip labels; otherwise False is
    returned. For every pair of tips, the leaves under that pair's MRCA are
    then compared between the two trees. The result is True only if they
    match for all pairs.

    Returns
    -------
    bool
    '''
    tips = t1.get_tip_labels()

    # Trees on different tip sets cannot share a topology. Without this check,
    # extra tips in t2 can go unnoticed and missing ones make get_mrca_node fail.
    if set(tips) != set(t2.get_tip_labels()):
        return False

    # all() stops at the first pair whose MRCA clade differs.
    return all(
        sorted(t1.get_mrca_node(*pair).get_leaf_names())
        == sorted(t2.get_mrca_node(*pair).get_leaf_names())
        for pair in combinations(tips, 2)
    )

# Show that, for a given recombination event, we can calculate the probability that the tree (and its topology) will change.
# (This requires running `ipcoal` with `record_full_arg=True`.)

In [3]:
# Per-replicate results for the FIRST recombination event on the locus:
#   treq_list -- OBSERVED: True if it left the TREE unchanged
#   toeq_list -- OBSERVED: True if it left the TOPOLOGY unchanged
#   trprobs   -- EXPECTED: probability that it leaves the TREE unchanged
#   toprobs   -- EXPECTED: probability that it leaves the TOPOLOGY unchanged
treq_list = []
toeq_list = []
trprobs = []
toprobs = []

# Specify number of replicates
nreps = 10000

# Range of Ne values to be drawn for each species-tree branch
ne_min = 10000
ne_max = 1000000

# Recombination rate
recomb = 1e-9

for looper in tqdm(range(nreps)):

    ### SPECIES TREE MODEL

    # Define a species tree topology
    st = toytree.rtree.bdtree(ntips=7, seed=1234)
    # Scale the root height to something reasonable
    st = st.mod.edges_scale_to_root_height(2e6)
    # Set a random Ne value on each branch
    st = st.set_node_data(
        "Ne",
        {i: np.random.randint(ne_min, ne_max) for i in range(st.nnodes)},
        default=150000,
    )

    ### IPCOAL MODEL

    # Define the ipcoal model using the species tree and recomb rate, taking
    # Ne values from the species tree.
    # NOTE(review): the heading above says this experiment needs
    # `record_full_arg=True`, but it is not passed here; confirm how the
    # full ARG was enabled when the recorded output was produced.
    mod = ipcoal.Model(st, recomb=recomb, Ne=None, nsamples=1)

    ### SIMULATE

    # Simulate a locus long enough to expect at least 1 recombination event...
    mod.sim_trees(nloci=1, nsites=2000)

    ### EVALUATE

    # Save the initial genealogy
    gt = toytree.tree(mod.df.genealogy[0])

    # Save IMAP, mapping individuals to species (used by the treetable function)
    imap = mod.get_imap_dict()

    # Record the PROBABILITY that the first recomb leaves the TREE unchanged
    trprobs.append(prob_unchange_tree(make_binary(gt), st, imap))
    # Record the PROBABILITY that the first recomb leaves the TOPOLOGY unchanged
    toprobs.append(get_unchange_prob(st, make_binary(gt), imap))

    # NOTE(review): the expectations above use genealogy[0], while the
    # observation below compares genealogy[1] with genealogy[2]; presumably
    # intentional for full-ARG output — confirm.
    first_tree = make_binary(toytree.tree(mod.df.genealogy[1]))
    second_tree = make_binary(toytree.tree(mod.df.genealogy[2]))

    # Record the OBSERVATION of whether the first recomb left the TREE unchanged.
    # (np.all: np.alltrue was deprecated in NumPy 1.25 and removed in 2.0.)
    treq_list.append(np.all(first_tree.get_node_data() == second_tree.get_node_data()))
    # Record the OBSERVATION of whether the first recomb left the TOPOLOGY unchanged
    toeq_list.append(test_rooted_topos_equal(first_tree, second_tree))


  0%|          | 0/10000 [00:00<?, ?it/s]

In [4]:
# For TOPOLOGY then TREE: draw one Bernoulli outcome per replicate from the
# predicted "unchanged" probabilities and print the mean, followed by the
# mean of the observed "unchanged" outcomes.
for predicted, observed in ((toprobs, toeq_list), (trprobs, treq_list)):
    print(np.mean(np.random.binomial(1, predicted)))
    print(np.mean(np.array(observed)))

0.7135
0.7155
0.3746
0.3742


# Now show that the expected distances (in bp) until the first tree change and the first topology change match those observed in `ipcoal` simulations.
# (This time `ipcoal` runs with `record_full_arg=False`.)

In [3]:
# Observed and expected distances (in bp) until the first change:
nbps_list = []      # OBSERVED bp until the first TOPOLOGY change
tr_nbps_list = []   # OBSERVED bp until the first TREE change
exp_list = []       # EXPECTED bp until the first TOPOLOGY change
tr_exp_list = []    # EXPECTED bp until the first TREE change

# Settings shared by every replicate.
nreps = 10000
ne_min = 10000      # lower bound of the Ne drawn per species-tree node
ne_max = 1000000    # upper bound (exclusive, per np.random.randint)
recomb = 1e-9       # recombination rate passed to ipcoal

for rep in tqdm(range(nreps)):

    # Same fixed species-tree topology each replicate, root scaled to 2e6,
    # with freshly drawn Ne values on every node.
    st = toytree.rtree.bdtree(ntips=7, seed=1234).mod.edges_scale_to_root_height(2e6)
    ne_draws = {i: np.random.randint(ne_min, ne_max) for i in range(st.nnodes)}
    st = st.set_node_data("Ne", ne_draws, default=150000)

    # Ne=None: ipcoal takes Ne values from the species tree.
    mod = ipcoal.Model(st, recomb=recomb, Ne=None, nsamples=1)
    # Simulate a locus long enough to expect at least 1 topology change.
    mod.sim_trees(nloci=1, nsites=20000)

    starting_topo = toytree.tree(mod.df.genealogy[0])

    # Walk along the locus, summing bp until a genealogy's rooted topology
    # differs from the starting one.
    nbps = 0
    topo_changed = False
    for row in mod.df.itertuples():
        if not test_rooted_topos_equal(toytree.tree(row.genealogy), starting_topo):
            topo_changed = True
            break
        nbps += row.nbps

    # Replicates with no topology change anywhere on the locus are skipped.
    if not topo_changed:
        continue

    # Observed distances: the first tree's span is the distance to the first
    # TREE change; nbps is the distance to the first TOPOLOGY change.
    tr_nbps_list.append(mod.df.iloc[0].nbps)
    nbps_list.append(nbps)

    # IMAP maps species to individuals (used by the treetable function).
    imap = mod.get_imap_dict()
    gtree_tot_len = np.sum(starting_topo.get_node_data().dist.iloc[:-1])

    # Probabilities that a recombination leaves the TREE / TOPOLOGY unchanged.
    trunch_prob = prob_unchange_tree(starting_topo, st, imap)
    unch_prob = get_unchange_prob(st, starting_topo, imap)

    # Rate of change per bp = recomb * P(change | recomb) * tree length;
    # the expected distance is taken as 1 / rate.
    BsT = 1 - unch_prob
    topo_rate = recomb * BsT * gtree_tot_len
    exp_list.append(1 / topo_rate)

    tree_rate = recomb * (1 - trunch_prob) * gtree_tot_len
    tr_exp_list.append(1 / tree_rate)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [5]:
# Mean ratio of observed to expected distance (bp) until the first change:
# TREE change first, then TOPOLOGY change (values near 1 indicate agreement).
for observed_bp, expected_bp in ((tr_nbps_list, tr_exp_list), (nbps_list, exp_list)):
    print(np.mean(np.array(observed_bp) / np.array(expected_bp)))

1.012860671757952
1.0169630747594829
