# Getting the probability distribution of a change in the genealogical tree

We want to find the waiting distance to a change in a genealogy based on the current genealogy and on the species tree that contains it.

In [51]:
import toytree
import toyplot
import ipcoal
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
def get_tree_total_length(ttree):
    """Return the summed length of all branches of a tree.

    ttree = tree object whose ``treenode.traverse()`` yields nodes exposing
            ``is_root()`` and ``dist``

    The ``dist`` of the root node is left out of the sum.
    """
    return sum(
        node.dist
        for node in ttree.treenode.traverse()
        if not node.is_root()
    )
def get_num_edges_at_time(tree, time):
    """Return one more than the number of nodes that sit strictly above ``time``.

    tree = tree object whose ``idx_dict`` maps node indices to nodes with a ``height``
    time = height at which the tree is cut

    For a bifurcating tree with its tips at or below ``time`` this is
    presumably the number of edges crossing that height.
    """
    higher_nodes = sum(1 for node in tree.idx_dict.values() if node.height > time)
    return higher_nodes + 1
def get_tree_clade_times(tree):
    """Tabulate the internal nodes of a tree as (clade, height) rows.

    tree = tree object whose ``treenode.traverse()`` yields nodes exposing
           ``is_leaf()``, ``get_leaf_names()`` and ``height``

    Returns a DataFrame with one row per non-leaf node, in traversal order:
        clades  = list of leaf names below the node
        heights = height of the node
    """
    internal_nodes = [nd for nd in tree.treenode.traverse() if not nd.is_leaf()]
    clade_names = [nd.get_leaf_names() for nd in internal_nodes]
    node_heights = [nd.height for nd in internal_nodes]
    # built row-wise and then transposed, so each clade list stays a single cell
    table = pd.DataFrame([clade_names, node_heights], index=['clades', 'heights'])
    return table.T
def get_branch_intervals(tr, gt, br):
    """Split one gene-tree branch into intervals with constant coalescence parameters.

    tr = species tree with Ne attribute
    gt = gene tree simulated on that species tree
    br = treenode representing a branch on the tree

    The branch spans ``br.height`` to ``br.height + br.dist``. It is cut at
    every species-tree node height and every relevant gene-tree node height
    that falls strictly inside that span.

    Returns a DataFrame with one row per interval and the columns:
        starts, stops, lengths = bounds and duration of the interval
        num_to_coal   = get_num_edges_at_time() of the reduced gene tree at the midpoint
        ne            = Ne of the species-tree MRCA node of the interval's clade
        reduced_trees = newick of the gene tree pruned to the interval's clade
        mids          = midpoint of the interval
    """
    st_times = get_tree_clade_times(tr)
    gt_times = get_tree_clade_times(gt)

    # Start at the species-tree MRCA of the tips below this branch, then climb
    # toward the root until reaching the species-tree branch whose top
    # (height + dist) is not below the bottom of the gene-tree branch.
    st_coal_node = tr.treenode.search_nodes(idx=tr.get_mrca_idx_from_tip_labels(br.get_leaf_names()))[0]
    nearest_st_node = st_coal_node
    while (nearest_st_node.height + nearest_st_node.dist) < br.height:
        if nearest_st_node.is_root():
            break
        nearest_st_node = nearest_st_node.up
    coalclade = nearest_st_node.get_leaf_names()

    br_lower = br.height
    br_upper = br_lower + br.dist
    gt_clade_changes = (gt_times.heights < br_upper) & (gt_times.heights > br_lower)
    st_clade_changes = (st_times.heights < br_upper) & (st_times.heights > br_lower)
    st_time_diffed = st_times[st_clade_changes]

    # Species-tree clades inside the branch span that contain every member of
    # coalclade, plus coalclade itself starting at the bottom of the branch.
    # Plain lists are used here: Series.append was removed in pandas 2.0, and
    # an empty comprehension mask used to select zero columns instead of zero rows.
    containing = [
        (clade, height)
        for clade, height in zip(st_time_diffed.clades, st_time_diffed.heights)
        if all(elem in clade for elem in coalclade)
    ]
    clade_col = [clade for clade, _ in containing] + [coalclade]
    height_col = [height for _, height in containing] + [br_lower]
    contains_clade = pd.DataFrame([clade_col, height_col], index=['clades', 'heights']).T
    contains_clade = contains_clade.sort_values('heights')

    # every tip name that appears in any of those clades
    all_members = np.unique([name for clade in contains_clade.clades for name in clade])

    relevant_coals = pd.DataFrame(columns=["heights"])

    # gene-tree nodes inside the branch span whose clade lies within all_members
    if np.sum(gt_clade_changes):
        potential_coals = gt_times[gt_clade_changes]
        relevant_coals = potential_coals[[set(i).issubset(all_members) for i in potential_coals.clades]]
        relevant_coals = relevant_coals.sort_values('heights')

    time_points = np.sort(list(contains_clade.heights) + list(relevant_coals.heights) + [br_upper])
    # Drop the last breakpoint when it truncates to the same integer as the one
    # before it, so no near-zero interval is emitted at the top of the branch.
    # NOTE(review): integer truncation is a coarse tolerance -- confirm intent.
    if int(time_points[-1]) == int(time_points[-2]):
        time_points = time_points[:-1]
    starts = time_points[:-1]
    stops = time_points[1:]
    lengths = stops-starts
    # placeholders; both columns are overwritten below once the clades are known
    num_to_coal = np.repeat(1,len(starts))
    ne = np.repeat(1,len(starts))
    a_df = pd.DataFrame([starts,stops,lengths,num_to_coal,ne],index=['starts','stops','lengths','num_to_coal','ne']).T
    mids = (a_df.stops + a_df.starts)/2

    interval_reduced_trees = []
    nes = []
    for mid in mids:
        # the clade in force at this midpoint is the last one starting below it
        clade = contains_clade.clades.iloc[np.sum(contains_clade.heights<mid)-1]

        cladeNe = tr.treenode.search_nodes(idx=tr.get_mrca_idx_from_tip_labels(clade))[0].Ne
        nes.append(cladeNe)
        reduced_tree = gt.prune(clade)
        interval_reduced_trees.append(reduced_tree.newick)

    a_df['reduced_trees'] = interval_reduced_trees
    a_df['mids'] = mids
    a_df['ne'] = nes
    a_df['num_to_coal'] = a_df.apply(lambda x: get_num_edges_at_time(toytree.tree(x['reduced_trees']), x['mids']), axis=1)

    return a_df

### Start with a species tree

In [2]:
# make a random tree with toytree's bdtree generator; the fixed seed makes it repeatable
tre = toytree.rtree.bdtree(6,time=8e3,seed=12345)

In [3]:
# rescale the tree to a root height of 8e3 so that the branch lengths make sense
tre = tre.mod.node_scale_root_height(treeheight=8e3)

### Set random Ne to each branch

In [4]:
# set a random Ne on each node, drawn uniformly from the integers 1..19999
# (the upper bound 20000 is exclusive); a seeded generator keeps the draw,
# and every number computed from it, reproducible between runs
ne_rng = np.random.default_rng(12345)
node_ne_dict = {i: int(ne_rng.integers(1, 20000)) for i in range(tre.nnodes)}
tre = tre.set_node_data('Ne',node_ne_dict)

In [5]:
# draw the species tree with node labels shown; the trailing ';' keeps the
# notebook from echoing the return value of draw()
tre.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);

### Now we can define an ipcoal model and simulate a gene tree

In [6]:
# define the model on the species tree, with a fixed tree seed;
# Ne=None presumably makes ipcoal read the per-node 'Ne' values -- confirm in the ipcoal docs
mod = ipcoal.Model(tre,Ne=None,seed_trees=1235)
# simulate a single gene tree (the result is read back from mod.df below)
mod.sim_trees(1)

### Let's look at the gene tree:

In [7]:
# extract the first simulated genealogy from the results table as a toytree object
gtr = toytree.tree(mod.df.genealogy[0])
# draw it with node labels shown, for comparison with the species tree
gtr.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);

### Notice that this gene tree does not match the species tree, because the species-tree branches are short when measured in coalescent units.

# Demonstrate how we define intervals:

### grab a node of a specific index from the gene tree

In [8]:
# grab the gene-tree node with idx 4; the branch analysed below is the edge
# running from this node's height up by its dist
mybranch = gtr.treenode.search_nodes(idx=4)[0]
print(mybranch)


--r5


### get the different `num_to_coal` and `ne` intervals for this branch

The coalescent probabilities will be piecewise constant, based on this table.

In [9]:
# build the interval table for the chosen gene-tree branch and display it
df = get_branch_intervals(tre,gtr,mybranch)
df

Unnamed: 0,starts,stops,lengths,num_to_coal,ne,reduced_trees,mids
0,0.0,3720.299839,3720.299839,1,13685,r5:17201.6;,1860.14992
1,3720.299839,5896.649001,2176.349162,1,13685,r5:17201.6;,4808.47442
2,5896.649001,8000.0,2103.350999,1,13685,r5:17201.6;,6948.324501
3,8000.0,14624.489361,6624.489361,4,9600,"(r0:17201.6,(r5:15773.5,...",11312.24468
4,14624.489361,15773.536641,1149.04728,3,9600,"(r0:17201.6,(r5:15773.5,...",15199.013001


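Notice there is some redundancy here: the first three intervals share the same `num_to_coal` and `ne`, so they could be merged into one. It doesn't matter for correctness, but might end up slowing down our computations a bit later on.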

# Proposition 1

Given a dataframe for a branch, and a time for the recombination event, get the probability of the tree being unchanged. 

In [12]:
def calc_P_btT(t, df):
    """Probability that the tree is unchanged by a recombination at time ``t``.

    t  = time of the recombination event on the branch; must lie between the
         first ``starts`` and the last ``stops`` of the table
    df = interval table for the branch, with the columns ``starts``,
         ``stops``, ``lengths``, ``num_to_coal`` and ``ne``, ordered from the
         bottom of the branch to the top

    Within each interval the rate is ``num_to_coal / (2 * ne)``, the same
    scaling that calc_P_bT applies. The result is the probability, summed
    over the intervals at or above ``t``, of reaching that interval and
    re-coalescing in it, each weighted by ``1 / num_to_coal``.

    Raises ValueError if the table is empty or ``t`` lies outside the branch.
    """
    starts = df['starts'].to_numpy(dtype=float)
    stops = df['stops'].to_numpy(dtype=float)
    lengths = df['lengths'].to_numpy(dtype=float)
    nedges = df['num_to_coal'].to_numpy(dtype=float)
    rates = nedges / (2.0 * df['ne'].to_numpy(dtype=float))

    if not len(starts):
        raise ValueError("interval table is empty")
    if t < starts[0] or t > stops[-1]:
        raise ValueError(
            f"t={t} lies outside the branch [{starts[0]}, {stops[-1]}]"
        )

    # interval containing t; t exactly at the bottom belongs to the first one
    interval_index = max(int(np.sum(starts < t)) - 1, 0)

    # Exponent accumulated between t and the bottom of the interval being
    # visited. Working with the difference (sigma_i - t) instead of
    # exp(-r*sigma_i) * exp(r*t) avoids the 0 * inf = nan overflow.
    exponent = rates[interval_index] * (stops[interval_index] - t)
    prob = -np.expm1(-exponent) / nedges[interval_index]

    # full intervals above t: survive everything below, then coalesce inside
    for k in range(interval_index + 1, len(starts)):
        within = rates[k] * lengths[k]
        prob += np.exp(-exponent) * (-np.expm1(-within)) / nedges[k]
        exponent += within
    return prob

### Calculate it at a bunch of values of x along the branch.

In [24]:
# evaluate the no-change probability at 100 evenly spaced times from 1 to 15700,
# which stays inside the span of the interval table shown above
xs = np.linspace(1, 15700,100)
ys = np.array([calc_P_btT(x,df) for x in xs])

In [28]:
# plot the no-change probability against the time of the recombination event
toyplot.plot(xs,ys,label='prob of tree unchanging on branch b given time t',xlabel="time = t",ylabel="probability of no tree change");

# Proposition 2

We can calculate the unchanging probability at points (t) along a branch. Now we want to integrate through all times t.

In [29]:
def calc_P_bT(df):
    """Probability that a recombination on this branch leaves the tree unchanged.

    df = interval table for the branch, with the columns ``starts``,
         ``stops``, ``lengths``, ``num_to_coal`` and ``ne``, ordered from the
         bottom of the branch to the top

    This is the no-change probability at time t, integrated over t along the
    whole branch and divided by the branch length (a uniform recombination
    position). The rate within an interval is ``num_to_coal / (2 * ne)``.

    With r_i, a_i, T_i the rate, num_to_coal and length of interval i, the
    integral over interval i has the closed form

        T_i / a_i + W_i * (S_i - 1 / a_i)

    where W_i = (1 - exp(-r_i T_i)) / r_i and S_i is the probability of
    re-coalescing somewhere above interval i, starting from its top. Every
    exponent is non-positive, so nothing can overflow.

    Raises ValueError if the table is empty.
    """
    starts = df['starts'].to_numpy(dtype=float)
    stops = df['stops'].to_numpy(dtype=float)
    lengths = df['lengths'].to_numpy(dtype=float)
    nedges = df['num_to_coal'].to_numpy(dtype=float)
    rates = nedges / (2.0 * df['ne'].to_numpy(dtype=float))

    if not len(starts):
        raise ValueError("interval table is empty")

    within = rates * lengths              # r_i * T_i
    coal_inside = -np.expm1(-within)      # 1 - exp(-r_i T_i)
    weights = coal_inside / rates         # W_i

    full_branch_summation = 0.0
    above = 0.0                           # S_i; nothing lies above the top interval
    # walk from the top interval down so S_i follows from S_{i+1} in one step
    for i in range(len(starts) - 1, -1, -1):
        full_branch_summation += lengths[i] / nedges[i] + weights[i] * (above - 1.0 / nedges[i])
        # S_{i-1}: coalesce inside interval i, or survive it and coalesce above
        above = coal_inside[i] / nedges[i] + np.exp(-within[i]) * above

    return full_branch_summation / (stops[-1] - starts[0])

### Now let's see what the probability is across the whole branch (assuming uniform probability of recombination location)

In [30]:
# no-change probability averaged over the whole branch described by df
calc_P_bT(df)

0.21833299291045577

# Loop across branches of the genealogy to get full probability of a recombination event not changing the tree

In [31]:
def get_unchange_prob(tre, gtr):
    """Probability that a recombination event on the genealogy leaves it unchanged.

    tre = species tree with Ne attribute
    gtr = gene tree (genealogy) contained in that species tree

    Every non-root branch contributes its own no-change probability
    (calc_P_bT of its interval table), weighted by the share of the total
    tree length that the branch makes up.
    """
    branches = [nd for nd in gtr.treenode.traverse() if not nd.is_root()]
    total_length = sum(nd.dist for nd in branches)

    weighted_probs = []
    for nd in branches:
        branch_prob = calc_P_bT(get_branch_intervals(tre, gtr, nd))
        weighted_probs.append((nd.dist / total_length) * branch_prob)
    return sum(weighted_probs)

In [32]:
# probability that a recombination event anywhere on the genealogy does not change the tree
get_unchange_prob(tre,gtr)

0.1490408666745664

## Use this to get expected wait time to a tree change.

In [35]:
def get_lambda(tre,gtr,recomb_rate):
    """Rate of the exponential waiting distance to the next tree change.

    tre         = species tree with Ne attribute
    gtr         = gene tree (genealogy) contained in that species tree
    recomb_rate = recombination rate, used as given

    The rate is recomb_rate times the probability that a recombination
    changes the tree times the total branch length of the genealogy.
    """
    total_length = get_tree_total_length(gtr)
    change_prob = 1 - get_unchange_prob(tre, gtr)
    return recomb_rate * change_prob * total_length

In [36]:
# lambda: rate of the exponential waiting distance, for a recombination rate of 1e-9
get_lambda(tre, gtr,1e-9)

6.332682579213137e-05

In [37]:
# the mean of an exponential distribution is 1/lambda: the expected wait to a tree change
1/get_lambda(tre, gtr,1e-9)

15791.096229621133

### Get probability of specific wait time:

In [39]:
# exponential density lam * exp(-lam * x), evaluated at a wait of 10000
lam = get_lambda(tre, gtr,1e-9)
wait_time = 10000
lam*np.exp(-lam*wait_time)

3.361729846118402e-05

# Show wait time pdf

In [44]:
# plot the same exponential density at 100 evenly spaced waits from 100 to 17000
toyplot.plot(np.linspace(100,17000,100),[lam*np.exp(-lam*wait_time) for wait_time in np.linspace(100,17000,100)]);

# Compare result to ipcoal simulations

In [48]:
# define the model with the same species tree, now with a recombination rate
mod = ipcoal.Model(tre,Ne=None,recomb=1e-9)  # NOTE: no seed_trees here, so this run is not seeded
# simulate one really long chromosome (1e8 sites) so that it holds many genealogies
mod.sim_loci(nsites=100000000)

In [49]:
# results table: one row per genealogy, with its start, end and length in sites (nbps)
mod.df

Unnamed: 0,locus,start,end,nbps,nsnps,tidx,genealogy
0,0,0,5713,5713,0,0,((r1:11883.5153963039138...
1,0,5713,9417,3704,0,1,((r0:13767.0981465603490...
2,0,9417,28815,19398,0,2,((r0:13767.0981465603490...
3,0,28815,127525,98710,0,3,(r5:16974.84214048814465...
4,0,127525,136824,9299,0,4,(r5:16974.84214048814465...
...,...,...,...,...,...,...,...
9461,0,99962576,99966073,3497,0,9461,(r4:34196.24662000239186...
9462,0,99966073,99967147,1074,0,9462,(r3:21656.34706159862617...
9463,0,99967147,99972380,5233,0,9463,(r0:18136.76835850355565...
9464,0,99972380,99972570,190,0,9464,((r1:11216.8852356540064...


### Now for each simulated genealogy, let's see what our predicted length is:

In [52]:
# predicted mean waiting distance (1/lambda) for each simulated genealogy;
# entry i lines up with the observed "nbps" in row i of mod.df
exp_dists = []
for gen_idx, newick in enumerate(tqdm(mod.df.genealogy)):
    currtree = toytree.tree(newick)
    lam = get_lambda(tre, currtree, 1e-9)
    expected_dist = 1/lam
    exp_dists.append(expected_dist)

  0%|          | 0/9466 [00:00<?, ?it/s]

## Big result -- ipcoal waiting distances match our predictions:

Perform a simple check that the means match:

In [53]:
# mean over genealogies of observed length / predicted mean length;
# a value near 1 means the predictions match the simulation on average
np.mean(mod.df.nbps/exp_dists)

1.0038351755676627

holy moly!

### Histogram of ipcoal nbps

In [54]:
# histogram (20 bins) of the ipcoal-simulated distances between tree changes
toyplot.bars(np.histogram(mod.df.nbps,bins=20));

### Histogram of nbps based on expectations:

In [59]:
# draw one exponential sample per genealogy, using its predicted mean distance as the scale
# NOTE: the draw is unseeded, so this histogram differs from run to run
samples_based_on_exp_dists = np.array([np.random.exponential(i) for i in exp_dists])
# histogram (20 bins) of the distances sampled from the predictions
toyplot.bars(np.histogram(samples_based_on_exp_dists,bins=20));

### Next steps: topology change. This is a simple -- but tedious -- extension.