In [1]:
from Bio import Phylo
from io import StringIO
import networkx as nx
import pandas as pd
from utilities import *
import matplotlib.pyplot as plt

problin_file = "/n/fs/ragr-research/projects/problin_experiments/Real_biodata/run_kptracer/result_trees/pruned_1461/problin_1461.nwk"
startle_file = "/n/fs/ragr-research/projects/problin_experiments/Real_biodata/run_kptracer/result_trees/pruned_1461/startle_1461.nwk"
casshybrid_file = "/n/fs/ragr-research/projects/problin_experiments/Real_biodata/run_kptracer/result_trees/pruned_1461/casshybrid_1461.nwk"
df_kp_meta = pd.read_csv(f'/n/fs/ragr-data/users/palash/multi-linTracer/kp_data/KPTracer-Data/KPTracer_meta.csv', index_col = 0)
df_character_matrix_Fam_1 = pd.read_csv('/n/fs/ragr-data/users/palash/multi-linTracer/kp_data/KPTracer-Data/trees/3724_NT_All_character_matrix.txt',
                                        index_col = 0, sep='\t', dtype=str)
df_character_matrix_Fam_1 = df_character_matrix_Fam_1.replace('-', '-1')

In [2]:
cell_site_info_1 = df_kp_meta.loc[df_character_matrix_Fam_1.index]['SubTumor']
for cell, site in cell_site_info_1.items():
    if len(site.split('_')) > 3:
        cell_site_info_1[cell] = '_'.join(site.split('_')[:-1])
        
cell_site_info_2 = df_kp_meta.loc[df_character_matrix_Fam_1.index]['SubTumor']
for cell, site in cell_site_info_2.items():
    if len(site.split('_')) > 3: 
        cell_site_info_2[cell] = '_'.join(site.split('_')[-2:])

In [3]:
def get_brlen(u, v, edge_attr):
    return edge_attr['weight']
    
def computeG(T, cell_site_info, state_list, cost_labeling, presence_labeling, node):
    min_cost = min(cost_labeling[node])
    if node == 'root':
         for state_idx in range(len(state_list)):
                if cost_labeling[node][state_idx] == min_cost:
                    presence_labeling[node][state_idx] = 1
                else:
                    presence_labeling[node][state_idx] = 0
    else:
        
        parent_node = list(T.pred[node])[0]
        
        for state_idx in range(len(state_list)):
            presence_labeling[node][state_idx] = 0
        
        for parent_state_idx in range(len(state_list)):
            if presence_labeling[parent_node][parent_state_idx] == 1:

                min_cost = np.inf
                for child_state_idx, child_state in enumerate(state_list):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[node][child_state_idx]
                    else:
                        curr_cost = cost_labeling[node][child_state_idx] + 1

                    if curr_cost < min_cost:
                        min_cost = curr_cost
        
                for child_state_idx in range(len(state_list)):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[node][child_state_idx]
                    else:
                        curr_cost = cost_labeling[node][child_state_idx] + 1
                        
                    if curr_cost == min_cost:
                        presence_labeling[node][child_state_idx] = 1
                

    for child in T[node]:
        computeG(T, cell_site_info, state_list, cost_labeling, presence_labeling, child)
        
def update_parsimonious_labeling_counts(T, cell_site_info, state_list, cost_labeling, count_labeling, node):
    if is_leaf(T, node):
        for state_idx, state in enumerate(state_list):
            if state == cell_site_info[node]:
                cost_labeling[node][state_idx] = 0
                count_labeling[node][state_idx] = 1
            else:
                cost_labeling[node][state_idx] = np.inf
                count_labeling[node][state_idx] = 0
    else:
        for child in T[node]:
            update_parsimonious_labeling_counts(T, cell_site_info, state_list, cost_labeling, count_labeling, child)
        
        for parent_state_idx, state in enumerate(state_list):
            cost_labeling[node][parent_state_idx] = 0
            count_labeling[node][parent_state_idx] = 1
            for child in T[node]:
                min_cost = np.inf
                for child_state_idx, child_state in enumerate(state_list):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[child][child_state_idx]
                    else:
                        curr_cost = cost_labeling[child][child_state_idx] + 1

                    if curr_cost < min_cost:
                        min_cost = curr_cost
                        min_child_state_idx = child_state_idx

                cost_labeling[node][parent_state_idx] += min_cost

                curr_child_total_count = 0
                for child_state_idx, child_state in enumerate(state_list):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[child][child_state_idx]
                    else:
                        curr_cost = cost_labeling[child][child_state_idx] + 1

                    if curr_cost == min_cost:
                        curr_child_total_count += count_labeling[child][child_state_idx]

                count_labeling[node][parent_state_idx] *= curr_child_total_count

def count_parsimonious_labelings(T, cell_site_info):
    state_list = list(cell_site_info.unique())
    nstates = len(state_list)
    # state_transition_counter = Counter()
    node_cost_labeling = {node:[None]*nstates for node in T.nodes}
    # print(T.nodes)
    # print(node_cost_labeling['root'])
    node_count_labeling = {node:[None]*nstates for node in T.nodes}
    update_parsimonious_labeling_counts(T, cell_site_info, state_list, node_cost_labeling, node_count_labeling, 'root')
    
    return node_cost_labeling, node_count_labeling
    
def count_transitions(input_tree, cell_site_info, throw_out_identical_sites=True):
    num_leaves = 0
    for n in input_tree.nodes:
        if is_leaf(input_tree, n): 
            num_leaves += 1
    
    print(f"input tree has {len(input_tree.nodes)} nodes and {num_leaves} leaves")
    if throw_out_identical_sites:
        input_tree = get_smaller_tree(input_tree, cell_site_info_1)
        print(f"modified tree has {len(input_tree.nodes)} nodes") 
    
    num_leaves = 0
    for n in input_tree.nodes:
        if is_leaf(input_tree, n): 
            num_leaves += 1
    print(f"modified tree has {num_leaves} leaves")
    cost_labeling, count_labeling = count_parsimonious_labelings(input_tree, cell_site_info)
    state_list = list(cell_site_info.unique())
    nstates = len(state_list)
    presence_labeling = {node:[None]*nstates for node in input_tree.nodes}
    computeG(input_tree, cell_site_info, state_list, cost_labeling, presence_labeling, 'root')

    optimal_labeling = {node:'' for node in input_tree.nodes}

    for node in input_tree.nodes:
        if presence_labeling[node][0] == 1:
            optimal_labeling[node] = 0
        else:
            optimal_labeling[node] = np.where(presence_labeling[node])[0][0]

    num_transitions = Counter()
    for edge in input_tree.edges:
        if optimal_labeling[edge[0]] != optimal_labeling[edge[1]]:
            num_transitions[(optimal_labeling[edge[0]], optimal_labeling[edge[1]])] += 1

    return optimal_labeling, presence_labeling, num_transitions, sum(num_transitions.values())


In [4]:
tree_path = problin_file
if tree_path.startswith('[&R]'):
    handle = StringIO(tree_path)
    phylo_tree = Phylo.read(handle, 'newick')
    net_tree = Phylo.to_networkx(phylo_tree)
else:
    phylo_tree = Phylo.read(tree_path, 'newick')
    net_tree = Phylo.to_networkx(phylo_tree)

In [5]:
root = list(net_tree.nodes)[0]
new_net_tree = net_tree.copy()
node_renaming_mapping = {}
name_to_node = {}
idx = 0
for node in net_tree.nodes:
    if str(node) == 'Clade':
        node_renaming_mapping[node] = f'clade_{idx}'
        name_to_node[f'clade_{idx}'] = node
        idx = idx + 1
    else:
        node_renaming_mapping[node] = str(node)
        name_to_node[str(node)] = node
node_renaming_mapping[list(net_tree.nodes)[0]] = 'root'


In [6]:
new_net_tree = nx.relabel_nodes(new_net_tree, node_renaming_mapping)
for node in net_tree.nodes:
    if str(node) == 'Clade':
        node_renaming_mapping[node] = f'clade_{idx}'
        idx = idx + 1
    else:
        node_renaming_mapping[node] = str(node)

name_to_node['root'] = name_to_node['clade_0']

attrs = {}
for new_node in new_net_tree.nodes:
    #print("new_node", new_node)
    node = name_to_node[new_node]
    #print(new_net_tree.edges(new_node))
    for edge in new_net_tree.edges(new_node):
        #print("edge", edge)
        u, v = edge
        weight = net_tree[name_to_node[u]][name_to_node[v]]['weight']
        attrs[(u, v)] = {'weight': weight}

In [7]:

nx.set_edge_attributes(new_net_tree, attrs)
directed_tree = nx.DiGraph()
directed_tree.add_edges_from(list(nx.bfs_edges(new_net_tree, 'root')))
nx.set_edge_attributes(directed_tree, attrs)


In [8]:
problin_1461 = directed_tree
mapping = dict()
for i, name in enumerate(cell_site_info_1.unique()): 
    mapping[i] = name

In [9]:
# get the number of metastases over # lineages

input_tree = problin_1461
# def count_transition_heights(input_tree, use_weights=False): 
transition_heights = dict()
a_optimal_labeling, a_presence_labeling, a_transitions, a_cost = count_transitions(input_tree, cell_site_info_1, False)
for edge in input_tree.edges:
    branch_len = input_tree.edges[edge]['weight']

    from_label = mapping[a_optimal_labeling[edge[0]]].strip()
    to_label = mapping[a_optimal_labeling[edge[1]]].strip()

    if from_label != to_label:
        from_primary = True if from_label == '3724_NT_T1' else False
        to_primary = True if to_label == '3724_NT_T1' else False
    
        if from_primary and to_primary:
            branch_type = "p2p_transition"
        elif from_primary and not to_primary:
            branch_type = "p2n_transition"
        elif not from_primary and not to_primary:
            branch_type = "n2n_transition"
            #print(from_label, to_label)
        elif not from_primary and to_primary:
            branch_type = "n2p_transition"
            # print(branch_type)
        else:
            branch_type = "hm?"
    else:
        branch_type = "not_transition"
    start_branch_height = nx.shortest_path_length(input_tree, source='root', target=edge[0], weight=get_brlen)
    end_branch_height = nx.shortest_path_length(input_tree, source='root', target=edge[1], weight=get_brlen)
    
    transition_heights[edge] = (start_branch_height, end_branch_height, branch_type)


input tree has 2921 nodes and 1461 leaves
modified tree has 1461 leaves


In [10]:
transition_heights

{('root', 'clade_1'): (0, 0.19311067051961742, 'not_transition'),
 ('root', 'clade_7'): (0, 0.005, 'not_transition'),
 ('clade_1', 'clade_2'): (0.19311067051961742,
  0.6578658037152288,
  'not_transition'),
 ('clade_1', 'L8.CGGAGCTTCGGAAATA-1'): (0.19311067051961742,
  2.4760264849530227,
  'not_transition'),
 ('clade_7', 'clade_8'): (0.005, 0.7928975489609431, 'not_transition'),
 ('clade_7', 'clade_10'): (0.005, 0.01, 'not_transition'),
 ('clade_2', 'clade_3'): (0.6578658037152288,
  1.042211541826986,
  'not_transition'),
 ('clade_2', 'L9.CTGCGGAGTAGGCTGA-1'): (0.6578658037152288,
  2.4760264849530227,
  'not_transition'),
 ('clade_8', 'clade_9'): (0.7928975489609431,
  1.3899931442025095,
  'not_transition'),
 ('clade_8', 'L8.TCACGAATCCCATTAT-1'): (0.7928975489609431,
  2.4760272189148385,
  'not_transition'),
 ('clade_10', 'clade_11'): (0.01, 0.1360013610013078, 'not_transition'),
 ('clade_10', 'clade_815'): (0.01, 0.015, 'not_transition'),
 ('clade_3', 'L9.GAACATCTCTCTAGGA-1'): (

In [20]:
# burst is 1.75 to 2.25 
scaled_intervals = [0.0, 1, 5, 6]
scale_factor = 6/2.47
intervals = [x/scale_factor for x in scaled_intervals]

In [21]:
epochs = ['Primary Growth', 'Metastasis Burst', 'Late Metastasis']


In [13]:
# # count the metastasis events in each epoch

# epoch_lineage_count = dict()
# epoch_met_count = dict()
# epoch_reseeding_count = dict()
# for epoch in epochs:
#     epoch_met_count[epoch] = 0
#     epoch_lineage_count[epoch] = 0
#     epoch_reseeding_count[epoch] = 0
    
# for edge in transition_heights:
#     branch_start, branch_end, branch_type = transition_heights[edge]
    
#     # find which epoch it lives in
#     # add to that epoch's count
#     branch_length = branch_end - branch_start
#     branch_interval = pd.Interval(branch_start, branch_end, closed='both')
#     for i, interval_start in enumerate(intervals[:-1]):
#         interval_end = intervals[i+1]
#         interval = pd.Interval(interval_start, interval_end, closed='both')
#         if interval.overlaps(branch_interval):
#             if branch_type == 'p2n_transition':
#                 epoch_met_count[epochs[i]] += 1
#             elif branch_type == 'n2p_transition':
#                 epoch_reseeding_count[epochs[i]] += 1
            
#             # if branch_type == "n2p_transition" or branch_type == "p2n_transition":
#             epoch_lineage_count[epochs[i]] += 1

            
                

In [103]:
# print("Metastasis:", epoch_met_count)
# print("Reseedings:", epoch_reseeding_count)
# print("Lineages:", epoch_lineage_count)

Metastasis: {'Primary Growth': 0, 'Metastasis Burst': 45, 'Late Metastasis': 37}
Reseedings: {'Primary Growth': 0, 'Metastasis Burst': 27, 'Late Metastasis': 25}
Lineages: {'Primary Growth': 44, 'Metastasis Burst': 2443, 'Late Metastasis': 1557}


In [105]:
# 25/(25 + 37), 37/(25 + 37)

(0.4032258064516129, 0.5967741935483871)

In [66]:
# print("Primary Growth", 0/44/1)
# print("Metastasis Burst", 16/829/2)
# print("Late Metastasis", 45/2364/3)

Primary Growth 0.0
Metastasis Burst 0.009650180940892641
Late Metastasis 0.00634517766497462


In [69]:
# # background rate of metastasis
# (3+12+47)/(86+212+2662)/6

0.0034909909909909913

In [71]:
# # bursting window in Metastasis stage
# 12/212/(.5)

0.11320754716981132

In [72]:
# # ratio of metastasis burst window to background rate of metastasis
# 0.11320754716981132/0.0034909909909909913

32.42848447961047

In [60]:
# (0.009650180940892641-0.00634517766497462)/0.009650180940892641

0.34248096446700493

In [14]:
def find_epoch(intervals, epochs, branch_start, branch_end):
    epoch_overlap = []
        
    branch_length = branch_end - branch_start
    branch_interval = pd.Interval(branch_start, branch_end, closed='both')
    for i, interval_start in enumerate(intervals[:-1]):
        interval_end = intervals[i+1]
        interval = pd.Interval(interval_start, interval_end, closed='both')
        epoch_overlap.append(interval.overlaps(branch_interval))
    return epoch_overlap

In [15]:
# # get all leaf incident edges
# with open("epoch_leaf_stats.csv", "w+") as w:
#     w.write("branch_length,epoch\n")
#     for edge in transition_heights:
#         f, t = edge
#         branch_start, branch_end, branch_type = transition_heights[edge]
#         if not t.startswith('clade') and not t.startswith('root'):
#             # is leaf
#             branch_length = branch_end - branch_start
#             epoch_overlap_bool = find_epoch(intervals, epochs, branch_start, branch_end)
#             for i, is_overlap in enumerate(epoch_overlap_bool):
#                 w.write(','.join([str(xx) for xx in [branch_length, epochs[i]]]) + '\n')
#                 break

In [33]:
epochs = ['Primary Growth', 'Metastasis', 'Late Metastasis']
# burst is 1.75 to 2.25 
scaled_intervals = [0.0, 1, 3, 6]
scale_factor = 6/2.47
intervals = [x/scale_factor for x in scaled_intervals]
intervals

[0.0, 0.4116666666666667, 1.235, 2.47]

In [35]:
# get all edges that lie entirely in one epoch
with open("epoch_all_stats.csv", "w+") as w:
    w.write("branch_length,epoch\n")
    for edge in transition_heights:
        f, t = edge
        branch_start, branch_end, branch_type = transition_heights[edge]
        branch_length = branch_end - branch_start
        epoch_overlap_bool = find_epoch(intervals, epochs, branch_start, branch_end)
        # only record if overlaps 1
        if sum(epoch_overlap_bool) == 1:
            for i, bool in enumerate(epoch_overlap_bool):
                if bool: 
                    w.write(','.join([str(xx) for xx in [branch_length, epochs[i]]]) + '\n')

        branch_interval = pd.Interval(branch_start, branch_end, closed='both')
        burst_interval = pd.Interval(1.75, 2.25, closed='both')
        before_burst_interval = pd.Interval(0, 1.75, closed='both')
        after_burst_interval = pd.Interval(2.25, 6, closed='both')
        if burst_interval.overlaps(branch_interval) and not before_burst_interval.overlaps(branch_interval) and not after_burst_interval.overlaps(branch_interval):
            w.write(','.join([str(xx) for xx in [branch_length, 'Metastasis Burst']]) + '\n')

In [39]:
epoch_met_count, epoch_lineage_count

({'Primary Growth': 4, 'Metastasis Burst': 13, 'Late Metastasis': 47},
 {'Primary Growth': 88, 'Metastasis Burst': 527, 'Late Metastasis': 2504})

In [44]:
print("Primary Growth", 4/88/1.8)
print("Metastasis Burst", 13/527/(2.7 - 1.8))
print("Late Metastasis", 47/2504/(6 - 2.7))

Primary Growth 0.025252525252525252
Metastasis Burst 0.027408812987560614
Late Metastasis 0.005687869106399459


In [22]:
# burst is 1.75 to 2.25 
scaled_intervals = [0.0, 1, 1.75, 2.25, 3, 6]
scale_factor = 6/2.47
intervals = [x/scale_factor for x in scaled_intervals]
intervals

[0.0, 0.4116666666666667, 0.7204166666666667, 0.92625, 1.235, 2.47]

In [23]:

all_transition_intervals = [transition_heights[x] for x in transition_heights]
# sort by starting 
all_transition_intervals.sort(key=lambda x: x[0])

all_transition_counts = dict()
for interval_start in intervals:
    all_transition_counts[interval_start] = []
    

In [24]:
all_transition_counts

{0.0: [],
 0.4116666666666667: [],
 0.7204166666666667: [],
 0.92625: [],
 1.235: [],
 2.47: []}

In [25]:
for i, interval_start in enumerate(intervals[:-1]):
    interval_end = intervals[i+1]
    interval = pd.Interval(interval_start, interval_end, closed='both')    
    interval_list = all_transition_counts[interval_start]
    for i in range(len(all_transition_intervals)):
        branch_start, branch_end, branch_transition_type = all_transition_intervals[i]
        branch_interval = pd.Interval(branch_start, branch_end, closed='both')
        
        if interval.overlaps(branch_interval):
            branch_length = branch_end - branch_start
            interval_list.append(branch_length)
        # else:
        #     break
    print(all_transition_counts)

{0.0: [0.19311067051961742, 0.005, 0.7878975489609431, 0.005, 0.1260013610013078, 0.004999999999999999, 0.11317298282909295, 0.005000000000000001, 0.005000000000000001, 0.0732079371446571, 0.00656203629714415, 0.3426195061621374, 0.15960905777811843, 2.4444676278964788, 0.0050000000000000044, 2.372604046732874, 0.08608758212655923, 0.0050000000000000044, 2.372822605634022, 0.09826523963712627, 0.3421027396523065, 0.007834202777416116, 0.12719259066899338, 0.8834217999120579, 0.2793610480308598, 2.3400213569905084, 0.30096176048765955, 0.2799862732873783, 0.3014199643365111, 1.7764992906932011, 0.46475513319561146, 2.282915814433405, 0.654759545422942, 0.09947074293105412, 0.201510948889362, 0.10509185598437326, 0.10411422024592998, 2.1750866551382932, 2.108410160982241, 2.108410160982241, 0.09390949094370116, 0.34403846468739974, 0.2238929323914972, 0.11500744310954081], 0.4116666666666667: [], 0.7204166666666667: [], 0.92625: [], 1.235: [], 2.47: []}
{0.0: [0.19311067051961742, 0.005,

In [27]:
# epochs = ['Primary Growth', 'Pre-Burst', 'Metastasis Burst', 'Post-Burst', 'Late Metastasis']
# with open("epoch_stats.csv", "w+") as w:
#     w.write("interval_start,branch_length,epoch\n")
#     for i, interval_start in enumerate(all_transition_counts):
#         interval_list = all_transition_counts[interval_start]
#         for branch_length in interval_list:    
#             w.write(','.join([str(xx) for xx in [interval_start, branch_length, epochs[i]]]) + '\n')