In [1]:
import toytree
import ipcoal
import numpy as np
import pandas as pd
def get_num_edges_at_time(tree, time):
    nodes_above = ([idx for idx, node in tree.idx_dict.items() if node.height > time])
    edges_above = len(nodes_above) + 1
    return edges_above
def get_tree_clade_times(tree):
    nodes_ = []
    heights_ = []
    for curr_node in tree.treenode.traverse():
        if not curr_node.is_leaf():
            nodes_.append(curr_node.get_leaf_names())
            heights_.append(curr_node.height)
    pddf = pd.DataFrame([nodes_,heights_],index=['clades','heights']).T
    return(pddf)

In [2]:
def get_tree_clade_times(tree):
    nodes_ = []
    heights_ = []
    for curr_node in tree.treenode.traverse():
        if not curr_node.is_leaf():
            nodes_.append(curr_node.get_leaf_names())
            heights_.append(curr_node.height)
    pddf = pd.DataFrame([nodes_,heights_],index=['clades','heights']).T
    return(pddf)

In [3]:
def get_branch_intervals(tr, gt, br):
    '''
    tr = species tree with Ne attribute
    gt = gene tree simulated on that species tree
    br = treenode representing a branch on the tree
    '''
    st_times = get_tree_clade_times(tr)
    gt_times = get_tree_clade_times(gt)
    coalclade = br.get_leaf_names()
    
    ###temp
    st_coal_node = tr.treenode.search_nodes(idx=tr.get_mrca_idx_from_tip_labels(br.get_leaf_names()))[0]
    nearest_st_node = st_coal_node
    while ((nearest_st_node.height + nearest_st_node.dist) < br.height):
        nearest_st_node = nearest_st_node.up
    coalclade = nearest_st_node.get_leaf_names()
    ###
    
    
    br_lower = br.height
    br_upper = br_lower + br.dist
    gt_clade_changes = (gt_times.heights < br_upper) & (gt_times.heights > br_lower)
    st_clade_changes = (st_times.heights < br_upper) & (st_times.heights > br_lower)
    st_time_diffed = st_times[st_clade_changes]
    #return(np.array([all(elem in clade for elem in coalclade) for clade in st_time_diffed.clades]))

    contains_clade = st_time_diffed[np.array([all(elem in clade for elem in coalclade) for clade in st_time_diffed.clades])]

    if not len(contains_clade.columns):
        contains_clade = pd.DataFrame(columns=['clades','heights'])
    contains_clade = pd.DataFrame([list(contains_clade.clades.append(pd.Series([coalclade]),ignore_index=True)),list(contains_clade.heights.append(pd.Series(br_lower)))],index=['clades','heights']).T
    contains_clade = contains_clade.sort_values('heights')

    all_members = []
    for i in contains_clade.clades:
        all_members.extend(i)
    all_members = np.unique(all_members)
    
    relevant_coals = pd.DataFrame(columns=["heights"])

    if np.sum(gt_clade_changes):
        potential_coals = gt_times[gt_clade_changes]
        relevant_coals = potential_coals[[set(i).issubset(all_members) for i in potential_coals.clades]]
        relevant_coals = relevant_coals.sort_values('heights')

    time_points = np.sort(list(contains_clade.heights) + list(relevant_coals.heights) + [br_upper])
    if int(time_points[-1]) == int(time_points[-2]):
        time_points = time_points[:-1]
    starts = time_points[:-1]
    stops = time_points[1:]
    lengths = stops-starts
    num_to_coal = np.repeat(1,len(starts))
    ne = np.repeat(1,len(starts))
    a_df = pd.DataFrame([starts,stops,lengths,num_to_coal,ne],index=['starts','stops','lengths','num_to_coal','ne']).T
    mids = (a_df.stops + a_df.starts)/2
    interval_reduced_trees=[]
    
    nes = []
    for mid in mids:
        clade = contains_clade.clades.iloc[np.sum(contains_clade.heights<mid)-1]

        cladeNe = tr.treenode.search_nodes(idx=tr.get_mrca_idx_from_tip_labels(clade))[0].Ne
        nes.append(cladeNe)
        reduced_tree = gt.prune(clade)
        interval_reduced_trees.append(reduced_tree.newick)
        
    a_df['reduced_trees'] = interval_reduced_trees
    a_df['mids'] = mids
    a_df['ne'] = nes
    a_df['num_to_coal'] = a_df.apply(lambda x: get_num_edges_at_time(toytree.tree(x['reduced_trees']), x['mids']), axis=1)
    
    return a_df

### let's demonstrate:

In [4]:
# make a random tree
tre = toytree.rtree.bdtree(12,time=1e6,seed=123)

In [5]:
# scale it so that branch lengths that make sense
tre = tre.mod.node_scale_root_height(treeheight=1e6)

In [6]:
# set a random Ne to each node
node_ne_dict = {i:np.random.randint(100,100000) for i in range(tre.nnodes)}
tre = tre.set_node_data('Ne',node_ne_dict)

In [7]:
tre.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);

### Now we can define a species tree model and simulate a gene tree

In [8]:
# define the model
mod = ipcoal.Model(tre,Ne=350000,seed_trees=1235)
# simulate a gene tree
mod.sim_trees(1)

### Let's look at the gene tree:

In [9]:
# extract the gene tree individually
gtr = toytree.tree(mod.df.genealogy[0])
# draw it
gtr.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);

### grab a node of a specific index from the gene tree

In [10]:
# grab node 21
mybranch = gtr.treenode.search_nodes(idx=21)[0]
print(mybranch)


   /-r0
--|
  |   /-r1
   \-|
      \-r2


### get the different `num_to_coal` and `ne` intervals for this branch

In [11]:
df = get_branch_intervals(tre,gtr,mybranch)
df

Unnamed: 0,starts,stops,lengths,num_to_coal,ne,reduced_trees,mids
0,310594.0,511354.3,200760.268966,2,67549,"((r0:310594,(r1:252445,r...",410974.2
1,511354.3,743480.6,232126.24634,2,67549,"((r0:310594,(r1:252445,r...",627417.4
2,743480.6,762956.6,19476.057159,4,1457,"((r0:310594,(r1:252445,r...",753218.6
3,762956.6,903008.5,140051.891918,3,1457,"((r0:310594,(r1:252445,r...",832982.6
4,903008.5,1000000.0,96991.492799,3,1457,"((r0:310594,(r1:252445,r...",951504.3
5,1000000.0,1154319.0,154318.777731,5,44886,"((r0:310594,(r1:252445,r...",1077159.0
6,1154319.0,1161669.0,7350.381479,4,44886,"((r0:310594,(r1:252445,r...",1157994.0
7,1161669.0,1314291.0,152621.463515,3,44886,"((r0:310594,(r1:252445,r...",1237980.0
8,1314291.0,1439421.0,125130.633044,2,44886,"((r0:310594,(r1:252445,r...",1376856.0


### now we can import parts of the math equation...

For each segment we should define variables --

In [12]:
len(df)

9

In [13]:
df.iloc[2]['ne']

1457

In [14]:
def get_i_idx(idx, df):
    ### get the i value from a particular interval index
    
    # ne for the interval
    ne_ = df.iloc[idx-1]['ne']
    # num to coal for the interval
    n2c = df.iloc[idx-1]['num_to_coal']
    
    # calculate i
    return(n2c / (2*ne_))

In [15]:
def get_sigma_idx(idx, df):
    return(df.iloc[idx-1]['stops'])

In [18]:
def get_length_idx(idx, df):
    return(df.iloc[idx-1]['lengths'])

In [20]:
df

Unnamed: 0,starts,stops,lengths,num_to_coal,ne,reduced_trees,mids
0,310594.0,511354.3,200760.268966,2,67549,"((r0:310594,(r1:252445,r...",410974.2
1,511354.3,743480.6,232126.24634,2,67549,"((r0:310594,(r1:252445,r...",627417.4
2,743480.6,762956.6,19476.057159,4,1457,"((r0:310594,(r1:252445,r...",753218.6
3,762956.6,903008.5,140051.891918,3,1457,"((r0:310594,(r1:252445,r...",832982.6
4,903008.5,1000000.0,96991.492799,3,1457,"((r0:310594,(r1:252445,r...",951504.3
5,1000000.0,1154319.0,154318.777731,5,44886,"((r0:310594,(r1:252445,r...",1077159.0
6,1154319.0,1161669.0,7350.381479,4,44886,"((r0:310594,(r1:252445,r...",1157994.0
7,1161669.0,1314291.0,152621.463515,3,44886,"((r0:310594,(r1:252445,r...",1237980.0
8,1314291.0,1439421.0,125130.633044,2,44886,"((r0:310594,(r1:252445,r...",1376856.0


In [19]:
get_length_idx(4, df)

140051.89191785455

In [35]:
for i in range(5,5):
    print(i)

In [61]:
def get_pij(interval_index, j, df):
    curri = get_i_idx(interval_index, df)
    currsigma = get_sigma_idx(interval_index, df)
    currj = get_i_idx(j, df)
    currjsigma = get_sigma_idx(j,df)
    currjlen = get_length_idx(j, df)
    
    mult = -curri*currsigma
    
    
    
    m = interval_index + 1
    summation = 0
    if (j-1) >= m:
        for _iter_ in range(m,j): # j-1 in math, but to j here because python range is not inclusive
            tempi = get_i_idx(_iter_, df)
            #tempsig = get_sigma_idx(_iter_, df)
            templen = get_length_idx(_iter_, df)
            tempprod = tempi*templen
            summation += tempprod
    
    #return(mult,summation)
    exp_portion = np.exp(mult - summation)
    
    numer = 1-np.exp(-currj*currjlen)
    denom = 1/currj
    
    return(exp_portion * numer * denom)
    

In [50]:
def get_pii(interval_index, df):
    curri = get_i_idx(interval_index, df)
    currsigma = get_sigma_idx(interval_index, df)
    
    return (-1 * (1/curri) * np.exp(-1*curri*currsigma))

### Prob of tree unchanged given time, branch, and topology

In [45]:
t = 311000

In [56]:
interval_idx_at_t = np.sum(df.starts < t)
interval_idx_at_t

1

In [64]:
i_at_t = get_i_idx(interval_idx_at_t,df)

pij_sum = 0
for j in range(interval_idx_at_t,len(df)+1):
    pij_sum += get_pij(interval_idx_at_t,j,df)

1/i_at_t + np.exp(i_at_t * t)*(get_pii(interval_idx_at_t,df) + pij_sum)

70739.29594656517

In [54]:
interval_idx_at_t

0

In [62]:
pij_sum = 0
for j in range(interval_idx_at_t,len(df)+1):
    pij_sum += get_pij(interval_idx_at_t,j,df)

In [63]:
pij_sum

66.76841842192805

In [58]:
np.exp(i_at_t * t)*get_pii(interval_idx_at_t,df)

-3479.172134404753

In [36]:
get_pii(4, df)

-0.05851617415242709

In [21]:
get_i_idx(5,df)

4.558051341890315e-05

Start with getting the total branch length for the outside of the equation...

In [24]:
outside_term_b = 1/mybranch.dist

In [25]:
outside_term_b

8.858751707321883e-07

In [26]:
kb = len(df)

In [27]:
for xt in range(1,kb+1):
    

SyntaxError: unexpected EOF while parsing (2446594651.py, line 1)

In [None]:
for i in range()