In [1]:
import pandas as pd
import sys
import argparse
import itertools
import math
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import cassiopeia as cass
import os
import json
import networkx as nx
import itertools
import ete3

import itertools
from collections import Counter
from collections import defaultdict
from collections import deque
from Bio import Phylo

import subprocess

from utilities import *
%load_ext autoreload
%autoreload 2

import gurobipy as gp

# Helper Functions

In [2]:
def get_smaller_tree(tree, cell_site_info):
    """Return a copy of ``tree`` with redundant same-site sibling leaves removed.

    For every node, its leaf children are scanned in adjacency order; the
    first leaf seen for each site is kept and any later leaf child with an
    already-seen site is removed.

    Parameters
    ----------
    tree : graph object providing ``copy()``, ``nodes``, ``tree[node]`` and
        ``remove_node`` (the notebook passes trees built by
        ``from_newick_get_nx_tree``).
    cell_site_info : mapping indexed by leaf name that yields the leaf's site.

    Returns
    -------
    The pruned copy; ``tree`` itself is only read and copied here.

    Side effects
    ------------
    Prints ``removing <k>`` where ``k`` is the number of leaves dropped.
    """
    pruned = tree.copy()
    redundant = []

    for parent in pruned.nodes:
        # Sites already represented by a leaf child of this parent.
        seen_sites = []
        for kid in pruned[parent]:
            # is_leaf is not defined in this notebook (presumably from utilities).
            if not is_leaf(pruned, kid):
                continue
            site = cell_site_info[kid]
            if site in seen_sites:
                redundant.append(kid)
            else:
                seen_sites.append(site)

    # Removal is deferred so the graph is not modified while it is iterated.
    for kid in redundant:
        pruned.remove_node(kid)
    print(f"removing {len(redundant)}")
    return pruned

In [3]:
def update_parsimonious_labeling_counts(T, cell_site_info, state_list, cost_labeling, count_labeling, node):
    """Fill per-state costs and optimal-labeling counts for the subtree at ``node``.

    Works bottom-up by recursion and writes into the two tables in place:

    * ``cost_labeling[node][i]`` - minimum number of state changes in the
      subtree below ``node`` when ``node`` is given ``state_list[i]``
      (an edge costs 0 if both ends share a state, 1 otherwise).
    * ``count_labeling[node][i]`` - number of subtree labelings achieving
      that minimum.

    Leaves get cost 0 / count 1 for their observed state
    (``cell_site_info[node]``) and cost ``np.inf`` / count 0 for all others.

    Parameters
    ----------
    T : tree; ``T[node]`` iterates over the children of ``node``.
    cell_site_info : mapping from leaf name to observed state.
    state_list : ordered list of states; positions index the table rows.
    cost_labeling, count_labeling : dict ``node -> list`` of length
        ``len(state_list)``, mutated in place.
    node : root of the subtree to process.

    Returns
    -------
    None.
    """
    if is_leaf(T, node):
        for idx, state in enumerate(state_list):
            cost, count = (0, 1) if state == cell_site_info[node] else (np.inf, 0)
            cost_labeling[node][idx] = cost
            count_labeling[node][idx] = count
        return

    children = list(T[node])
    for child in children:
        update_parsimonious_labeling_counts(T, cell_site_info, state_list, cost_labeling, count_labeling, child)

    n_states = len(state_list)
    for parent_idx in range(n_states):
        subtree_cost = 0
        subtree_count = 1
        for child in children:
            child_costs = cost_labeling[child]
            # Cost of each child state seen from this parent state:
            # unchanged when the states agree, one extra change otherwise.
            edge_costs = [
                child_costs[k] if k == parent_idx else child_costs[k] + 1
                for k in range(n_states)
            ]
            cheapest = min(edge_costs, default=np.inf)
            subtree_cost += cheapest

            # Children are independent, so their optimal-labeling counts
            # multiply; within one child, tied optimal states add up.
            subtree_count *= sum(
                count_labeling[child][k]
                for k in range(n_states)
                if edge_costs[k] == cheapest
            )

        cost_labeling[node][parent_idx] = subtree_cost
        count_labeling[node][parent_idx] = subtree_count


In [4]:
def count_parsimonious_labelings(T, cell_site_info):
    """Build the per-node cost and count tables for tree ``T``.

    The state order is ``list(cell_site_info.unique())``; both tables are
    dicts mapping every node of ``T`` to a list with one slot per state,
    filled by ``update_parsimonious_labeling_counts`` starting from the node
    literally named ``'root'``.

    Parameters
    ----------
    T : tree exposing ``nodes`` and containing a node named ``'root'``.
    cell_site_info : object with ``unique()`` that is also indexable by leaf
        name (a pandas Series in this notebook).

    Returns
    -------
    (node_cost_labeling, node_count_labeling)
    """
    states = list(cell_site_info.unique())
    width = len(states)

    costs = {}
    counts = {}
    for n in T.nodes:
        costs[n] = [None] * width
        counts[n] = [None] * width

    update_parsimonious_labeling_counts(T, cell_site_info, states, costs, counts, 'root')

    return costs, counts

In [5]:
def computeG(T, cell_site_info, state_list, cost_labeling, presence_labeling, node):
    """Mark, top-down, which states are candidates at each node.

    ``presence_labeling[node][i]`` is set to 1 or 0 in place:

    * at the node named ``'root'``: 1 exactly for the states whose cost in
      ``cost_labeling`` equals the minimum root cost;
    * at any other node: 1 for every state that is a cheapest choice
      (child cost, plus 1 if it differs from the parent state) for at least
      one parent state already marked 1; 0 otherwise.

    The function then recurses into the children ``T[node]``.

    Parameters
    ----------
    T : tree; ``T[node]`` iterates over children, ``T.pred[node]`` over parents
        (the first parent is used).
    cell_site_info : not read here, only forwarded to the recursive calls.
    state_list : ordered states; only its length is used.
    cost_labeling : dict ``node -> list`` of per-state subtree costs.
    presence_labeling : dict ``node -> list`` mutated in place.
    node : node to process.

    Returns
    -------
    None.
    """
    n_states = len(state_list)
    node_costs = cost_labeling[node]
    best_here = min(node_costs)
    flags = presence_labeling[node]

    if node == 'root':
        for k in range(n_states):
            flags[k] = 1 if node_costs[k] == best_here else 0
    else:
        parent = list(T.pred[node])[0]

        for k in range(n_states):
            flags[k] = 0

        for parent_idx in range(n_states):
            if presence_labeling[parent][parent_idx] != 1:
                continue

            # Cost of each state at this node seen from the parent state.
            edge_costs = [
                node_costs[k] if k == parent_idx else node_costs[k] + 1
                for k in range(n_states)
            ]
            cheapest = min(edge_costs)

            for k in range(n_states):
                if edge_costs[k] == cheapest:
                    flags[k] = 1

    for child in T[node]:
        computeG(T, cell_site_info, state_list, cost_labeling, presence_labeling, child)

# Read in Inputs

In [6]:
# KP-Tracer metadata table; the first CSV column becomes the index (it is
# later looked up with the character-matrix index, so presumably cell IDs).
# NOTE(review): absolute cluster path, and the f-string prefix has no placeholders.
df_kp_meta = pd.read_csv(f'/n/fs/ragr-data/users/palash/multi-linTracer/kp_data/KPTracer-Data/KPTracer_meta.csv', index_col = 0)

In [7]:
# Tab-separated character matrix for 3724_NT_All, read entirely as strings
# with the first column as the index.
df_character_matrix_Fam_1 = pd.read_csv('/n/fs/ragr-data/users/palash/multi-linTracer/kp_data/KPTracer-Data/trees/3724_NT_All_character_matrix.txt',
                                        index_col = 0, sep='\t', dtype=str)
# Cells whose whole value is '-' become '-1' (presumably the missing-data code — confirm).
df_character_matrix_Fam_1 = df_character_matrix_Fam_1.replace('-', '-1')

In [8]:
# SubTumor annotation for every cell in the character matrix.
# Rows and column are selected in a single .loc call and copied, so the
# Series is an independent object (the previous chained `.loc[rows]['col']`
# followed by item assignment is the pattern pandas warns about).
cell_site_info_1 = df_kp_meta.loc[df_character_matrix_Fam_1.index, 'SubTumor'].copy()

# Labels with more than three '_'-separated fields lose their last field;
# shorter labels are kept unchanged. Done with one map() instead of assigning
# into the Series element by element while iterating over it.
cell_site_info_1 = cell_site_info_1.map(
    lambda site: '_'.join(site.split('_')[:-1]) if len(site.split('_')) > 3 else site
)

In [9]:
# Persist the cell -> site annotations as a comma-separated file in the
# current working directory (relative path).
cell_site_info_1.to_csv('cell_annotations.csv', sep=',')

In [10]:
# Leaf/node counts below are taken from the count_transitions outputs recorded
# further down in this notebook; "resolved"/"unresolved" are the author's labels.
# from_newick_get_nx_tree is not defined in this notebook (presumably from utilities).

# 21,108 leaves (22,748 nodes), unresolved
cass_T_Fam = from_newick_get_nx_tree('/n/fs/ragr-data/users/palash/multi-linTracer/kp_data/KPTracer-Data/trees/3724_NT_All_tree.nwk')

# 1,460 leaves (2,798 nodes), unresolved
cass_T_Fam_pruned = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/data_kptracer/all/3724_NT_All/3724_NT_All_tree_pruned.nwk')

# 21,108 leaves (21,816 nodes), unresolved
linTracer_T_Fam = from_newick_get_nx_tree('/n/fs/ragr-research/projects/lineage-tracing-nni/data/kp-tracer/3724_NT_All_published_seed/fixed_nni_tree.nwk')

# 1,460 leaves (2,919 nodes), resolved
problin_startle = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/run_kptracer/runjobs/3724_NT_All/problin_startle/EM_problin_startle_toposearch2_rep2._ckpt.208.txt.nwk')

# 1,460 leaves (1,871 nodes), unresolved
linTracer_T_Fam_pruned = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/data_kptracer/all/3724_NT_All/fixed_nni_tree_pruned.nwk')


In [11]:
# Unresolved; the recorded count_transitions output below reports 21,107 leaves
# (24,026 nodes) for this tree, not 21,108 — NOTE(review): confirm the missing leaf.
problin_startle_full_tree = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/notebooks/problin_startle_full_tree.nwk')

In [12]:
# 21,107 leaves (43,282 nodes per the recorded output below), resolved.
# Loaded from a path relative to the notebook's working directory.
problin_startle_full_tree_pars_resolved = from_newick_get_nx_tree('problin_startle_full_tree.pars_resolved.nwk')

In [13]:
# 21,108 leaves (42,487 nodes per the recorded output below), resolved.
# Loaded from a path relative to the notebook's working directory.
startle_pars_resolved = from_newick_get_nx_tree('startle.pars_resolved.nwk')

In [14]:
# Full trees with 21,108 leaves each (recorded node counts: startle 22,980,
# problin 24,029, cassh 23,908). Paths are relative to the working directory.
startle_21108 = from_newick_get_nx_tree('startle_full_tree_21108.nwk')
problin_21108 = from_newick_get_nx_tree('problin_startle_full_tree_21108.nwk')
cassh_21108 = from_newick_get_nx_tree('casshybrid_full_tree_21108.nwk')

# "idresolved" variants of the same three trees.
# NOTE(review): the recorded outputs below show identical node counts and
# identical transition counts for each *_idr_* tree and its non-idr
# counterpart — confirm the two sets of files actually differ.
startle_idr_21108 = from_newick_get_nx_tree('startle_idresolved_21108.nwk')
problin_idr_21108 = from_newick_get_nx_tree('problin_startle_idresolved_21108.nwk')
cassh_idr_21108 = from_newick_get_nx_tree('casshybrid_idresolved_21108.nwk')

# 1,461 leaves each (recorded node counts: problin 2,921, startle 1,872, cassh 2,800)
problin_1461 = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/run_kptracer/runjobs/3724_NT_All/problin_benchmarks/problin_1461.txt.nwk')
startle_1461 = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/data_kptracer/all/3724_NT_All/pruned/fixed_nni_tree_pruned.nwk')
cassh_1461 = from_newick_get_nx_tree('/n/fs/ragr-research/projects/problin_experiments/Real_biodata/data_kptracer/all/3724_NT_All/pruned/casshybrid_tree.nwk')


# Set up Functions

In [15]:

def count_transitions(input_tree, cell_site_info, throw_out_identical_sites=True):
    """Label every node of a tree with a site and count site changes along edges.

    Steps:
      1. Optionally collapse same-site sibling leaves with ``get_smaller_tree``.
      2. Compute per-node, per-state costs with ``count_parsimonious_labelings``.
      3. Mark candidate states per node with ``computeG``.
      4. Give each node its lowest-index candidate state and count the edges
         whose two ends carry different states.

    Parameters
    ----------
    input_tree : tree with a node named ``'root'``, exposing ``nodes`` and
        ``edges`` (parent, child).
    cell_site_info : pandas Series mapping leaf name to site; its
        ``unique()`` order defines the integer state indices.
    throw_out_identical_sites : bool, default True
        When true, step 1 is applied before labeling.

    Returns
    -------
    (Counter, int)
        Counter keyed by ``(parent_state_index, child_state_index)`` for
        edges whose ends differ, and the total number of such edges.

    Side effects
    ------------
    Prints node/leaf counts before and after the optional pruning. The
    "modified tree has ... leaves" line is printed even when no pruning ran.
    """
    def _num_leaves(tree):
        # is_leaf is not defined in this notebook (presumably from utilities).
        return sum(1 for n in tree.nodes if is_leaf(tree, n))

    print(f"input tree has {len(input_tree.nodes)} nodes and {_num_leaves(input_tree)} leaves")
    if throw_out_identical_sites:
        # Prune with the annotations passed by the caller; this used to read
        # the notebook-global cell_site_info_1 and ignore the argument.
        input_tree = get_smaller_tree(input_tree, cell_site_info)
        print(f"modified tree has {len(input_tree.nodes)} nodes")

    print(f"modified tree has {_num_leaves(input_tree)} leaves")

    # count_labeling (number of optimal labelings) is computed but not used here.
    cost_labeling, count_labeling = count_parsimonious_labelings(input_tree, cell_site_info)
    state_list = list(cell_site_info.unique())
    nstates = len(state_list)
    presence_labeling = {node: [None] * nstates for node in input_tree.nodes}
    computeG(input_tree, cell_site_info, state_list, cost_labeling, presence_labeling, 'root')

    # Lowest-index candidate state per node, as a plain Python int so that the
    # Counter keys have one consistent type.
    # NOTE(review): labels are picked independently per node, while a node's
    # candidate set is the union over all candidate parent states. The state
    # picked for a child is therefore not guaranteed to be optimal for the
    # state picked for its parent — confirm that the returned total equals
    # min(cost_labeling['root']) before treating it as the parsimony score.
    optimal_labeling = {
        node: presence_labeling[node].index(1) for node in input_tree.nodes
    }

    num_transitions = Counter()
    for parent, child in input_tree.edges:
        src = optimal_labeling[parent]
        dst = optimal_labeling[child]
        if src != dst:
            num_transitions[(src, dst)] += 1

    return num_transitions, sum(num_transitions.values())


# Full Trees Identically Resolved

In [16]:
# Problin "idresolved" full tree; third argument False keeps every leaf (no same-site pruning).
count_transitions(problin_idr_21108, cell_site_info_1, False)

input tree has 24029 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 4): 72,
          (0, 1): 30,
          (0, 3): 17,
          (1, 0): 150,
          (1, 2): 2,
          (0, 2): 1244,
          (2, 0): 18,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 0): 20,
          (1, 4): 2,
          (3, 4): 1,
          (4, 0): 3,
          (1, 3): 1}),
 1568)

In [17]:
# Cassiopeia-hybrid "idresolved" full tree; all leaves kept.
count_transitions(cassh_idr_21108, cell_site_info_1, False)

input tree has 23908 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 21,
          (0, 4): 71,
          (0, 3): 13,
          (1, 0): 173,
          (3, 0): 28,
          (0, 2): 1261,
          (1, 4): 2,
          (1, 3): 1,
          (1, 2): 2,
          (2, 4): 1,
          (2, 0): 14,
          (4, 0): 4,
          (3, 4): 1,
          (3, 2): 6,
          (3, 1): 1}),
 1599)

In [18]:
# Startle "idresolved" full tree; all leaves kept.
count_transitions(startle_idr_21108, cell_site_info_1, False)

input tree has 22980 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 21,
          (1, 0): 162,
          (0, 4): 72,
          (0, 3): 15,
          (0, 2): 1260,
          (1, 2): 2,
          (1, 3): 1,
          (1, 4): 2,
          (3, 0): 24,
          (2, 0): 14,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 4): 1,
          (4, 0): 3}),
 1585)

# Full Trees Polytomies

In [19]:
# Problin full tree (polytomy version per the section heading); all leaves kept.
count_transitions(problin_21108, cell_site_info_1, False)

input tree has 24029 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 4): 72,
          (0, 1): 30,
          (0, 3): 17,
          (1, 0): 150,
          (1, 2): 2,
          (0, 2): 1244,
          (2, 0): 18,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 0): 20,
          (1, 4): 2,
          (3, 4): 1,
          (4, 0): 3,
          (1, 3): 1}),
 1568)

In [20]:
# Cassiopeia-hybrid full tree (polytomy version per the section heading); all leaves kept.
count_transitions(cassh_21108, cell_site_info_1, False)

input tree has 23908 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 21,
          (0, 4): 71,
          (0, 3): 13,
          (1, 0): 173,
          (3, 0): 28,
          (0, 2): 1261,
          (1, 4): 2,
          (1, 3): 1,
          (1, 2): 2,
          (2, 4): 1,
          (2, 0): 14,
          (4, 0): 4,
          (3, 4): 1,
          (3, 2): 6,
          (3, 1): 1}),
 1599)

In [21]:
# Startle full tree (polytomy version per the section heading); all leaves kept.
count_transitions(startle_21108, cell_site_info_1, False)

input tree has 22980 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 21,
          (1, 0): 162,
          (0, 4): 72,
          (0, 3): 15,
          (0, 2): 1260,
          (1, 2): 2,
          (1, 3): 1,
          (1, 4): 2,
          (3, 0): 24,
          (2, 0): 14,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 4): 1,
          (4, 0): 3}),
 1585)

# 1461 Leaves

In [22]:

# Cassiopeia-hybrid tree on the 1,461-leaf subset; all leaves kept.
count_transitions(cassh_1461, cell_site_info_1, False)

input tree has 2800 nodes and 1461 leaves
modified tree has 1461 leaves


(Counter({(0, 1): 15,
          (0, 4): 9,
          (0, 3): 7,
          (1, 0): 57,
          (3, 0): 11,
          (0, 2): 36,
          (1, 2): 1}),
 136)

In [23]:

# Problin tree on the 1,461-leaf subset; all leaves kept.
count_transitions(problin_1461, cell_site_info_1, False)

input tree has 2921 nodes and 1461 leaves
modified tree has 1461 leaves


(Counter({(0, 4): 9,
          (0, 1): 15,
          (0, 3): 11,
          (1, 0): 37,
          (0, 2): 19,
          (3, 0): 4,
          (2, 0): 3,
          (1, 2): 1}),
 99)

In [24]:

# Startle tree on the 1,461-leaf subset; all leaves kept.
count_transitions(startle_1461, cell_site_info_1, False)

input tree has 1872 nodes and 1461 leaves
modified tree has 1461 leaves


(Counter({(0, 1): 9,
          (1, 0): 48,
          (0, 4): 9,
          (0, 3): 9,
          (0, 2): 35,
          (1, 2): 1,
          (3, 0): 8}),
 119)

## Cassiopeia-Hybrid (Published)

In [25]:
# Published Cassiopeia-hybrid tree; True collapses same-site sibling leaves first.
count_transitions(cass_T_Fam, cell_site_info_1, True)

input tree has 22748 nodes and 21108 leaves
removing 19548
modified tree has 3200 nodes
modified tree has 1560 leaves


(Counter({(0, 1): 19,
          (0, 4): 15,
          (0, 3): 12,
          (1, 0): 63,
          (0, 2): 85,
          (1, 4): 2,
          (3, 0): 11,
          (1, 2): 2,
          (1, 3): 1,
          (3, 2): 1,
          (3, 4): 1,
          (3, 1): 1}),
 213)

In [26]:
# Published Cassiopeia-hybrid tree; all leaves kept.
count_transitions(cass_T_Fam, cell_site_info_1, False)

input tree has 22748 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 19,
          (0, 4): 71,
          (0, 3): 13,
          (1, 0): 175,
          (0, 2): 1261,
          (1, 4): 2,
          (3, 0): 28,
          (1, 2): 2,
          (1, 3): 1,
          (2, 0): 15,
          (2, 4): 1,
          (3, 2): 6,
          (3, 4): 1,
          (4, 0): 4,
          (3, 1): 1}),
 1600)

In [27]:
# Cassiopeia-hybrid tree pruned to 1,460 leaves; same-site sibling leaves collapsed first.
count_transitions(cass_T_Fam_pruned, cell_site_info_1, True)

input tree has 2798 nodes and 1460 leaves
removing 258
modified tree has 2540 nodes
modified tree has 1202 leaves


(Counter({(0, 1): 12,
          (0, 4): 9,
          (0, 3): 7,
          (1, 0): 50,
          (3, 0): 8,
          (0, 2): 35,
          (1, 2): 1}),
 122)

In [28]:
# Cassiopeia-hybrid tree pruned to 1,460 leaves; all leaves kept.
count_transitions(cass_T_Fam_pruned, cell_site_info_1, False)

input tree has 2798 nodes and 1460 leaves
modified tree has 1460 leaves


(Counter({(0, 1): 15,
          (0, 4): 9,
          (0, 3): 7,
          (1, 0): 57,
          (3, 0): 11,
          (0, 2): 36,
          (1, 2): 1}),
 136)

## Startle (Published)

In [29]:
# Published Startle tree (fixed_nni_tree.nwk); same-site sibling leaves collapsed first.
count_transitions(linTracer_T_Fam, cell_site_info_1, True)

input tree has 21816 nodes and 21108 leaves
removing 20317
modified tree has 1499 nodes
modified tree has 791 leaves


(Counter({(0, 1): 13,
          (1, 0): 40,
          (0, 3): 13,
          (0, 4): 10,
          (0, 2): 66,
          (1, 3): 1,
          (1, 4): 2,
          (3, 0): 6,
          (1, 2): 1,
          (3, 1): 1,
          (3, 2): 2,
          (3, 4): 1}),
 156)

In [30]:
# Published Startle tree (fixed_nni_tree.nwk); all leaves kept.
count_transitions(linTracer_T_Fam, cell_site_info_1, False)

input tree has 21816 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 16,
          (1, 0): 356,
          (0, 4): 84,
          (0, 3): 15,
          (0, 2): 1260,
          (1, 2): 2,
          (1, 3): 1,
          (1, 4): 2,
          (3, 0): 24,
          (2, 0): 14,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 4): 1}),
 1783)

In [31]:
# Startle tree from startle.pars_resolved.nwk; same-site sibling leaves collapsed first.
count_transitions(startle_pars_resolved, cell_site_info_1, True)

input tree has 42487 nodes and 21108 leaves
removing 578
modified tree has 41909 nodes
modified tree has 20530 leaves


(Counter({(0, 1): 15,
          (0, 4): 10,
          (0, 3): 16,
          (0, 2): 66,
          (1, 0): 38,
          (1, 4): 2,
          (3, 1): 1,
          (3, 2): 2,
          (1, 2): 1,
          (3, 0): 3,
          (3, 4): 1,
          (1, 3): 1}),
 156)

In [32]:
# Startle tree from startle.pars_resolved.nwk; all leaves kept.
count_transitions(startle_pars_resolved, cell_site_info_1, False)

input tree has 42487 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 15,
          (0, 4): 10,
          (0, 3): 16,
          (0, 2): 66,
          (1, 0): 38,
          (1, 4): 2,
          (3, 1): 1,
          (3, 2): 2,
          (1, 2): 1,
          (3, 0): 3,
          (3, 4): 1,
          (1, 3): 1}),
 156)

In [33]:
# Startle tree pruned to 1,460 leaves; same-site sibling leaves collapsed first.
count_transitions(linTracer_T_Fam_pruned, cell_site_info_1, True)

input tree has 1871 nodes and 1460 leaves
removing 1013
modified tree has 858 nodes
modified tree has 447 leaves


(Counter({(0, 1): 9,
          (1, 0): 33,
          (0, 4): 8,
          (0, 3): 9,
          (0, 2): 17,
          (1, 2): 1,
          (3, 0): 5}),
 82)

In [34]:
# Startle tree pruned to 1,460 leaves; all leaves kept.
count_transitions(linTracer_T_Fam_pruned, cell_site_info_1, False)

input tree has 1871 nodes and 1460 leaves
modified tree has 1460 leaves


(Counter({(0, 1): 9,
          (1, 0): 48,
          (0, 4): 9,
          (0, 3): 9,
          (0, 2): 35,
          (1, 2): 1,
          (3, 0): 8}),
 119)

## Problin-startle

In [35]:
# Problin-startle tree (1,460 leaves); same-site sibling leaves collapsed first.
count_transitions(problin_startle, cell_site_info_1, True)

input tree has 2919 nodes and 1460 leaves
removing 302
modified tree has 2617 nodes
modified tree has 1158 leaves


(Counter({(0, 4): 9,
          (0, 1): 15,
          (0, 3): 11,
          (1, 0): 37,
          (0, 2): 19,
          (3, 0): 4,
          (2, 0): 3,
          (1, 2): 1}),
 99)

In [36]:
# Problin-startle tree (1,460 leaves); all leaves kept.
count_transitions(problin_startle, cell_site_info_1, False)

input tree has 2919 nodes and 1460 leaves
modified tree has 1460 leaves


(Counter({(0, 4): 9,
          (0, 1): 15,
          (0, 3): 11,
          (1, 0): 37,
          (0, 2): 19,
          (3, 0): 4,
          (2, 0): 3,
          (1, 2): 1}),
 99)

In [37]:
# Problin-startle full tree; same-site sibling leaves collapsed first.
count_transitions(problin_startle_full_tree, cell_site_info_1, True)

input tree has 24026 nodes and 21107 leaves
removing 19548
modified tree has 4478 nodes
modified tree has 1559 leaves


(Counter({(0, 4): 15,
          (0, 1): 23,
          (0, 3): 15,
          (1, 0): 43,
          (1, 2): 2,
          (0, 2): 65,
          (3, 1): 1,
          (3, 2): 2,
          (3, 0): 4,
          (2, 0): 5,
          (1, 4): 2,
          (3, 4): 1,
          (1, 3): 1}),
 179)

In [38]:
# Problin-startle full tree; all leaves kept.
count_transitions(problin_startle_full_tree, cell_site_info_1, False)

input tree has 24026 nodes and 21107 leaves
modified tree has 21107 leaves


(Counter({(0, 4): 72,
          (0, 1): 30,
          (0, 3): 17,
          (1, 0): 150,
          (1, 2): 2,
          (0, 2): 1244,
          (2, 0): 18,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 0): 20,
          (1, 4): 2,
          (3, 4): 1,
          (4, 0): 3,
          (1, 3): 1}),
 1568)

In [39]:
# Problin-startle full tree from the .pars_resolved.nwk file; same-site sibling leaves collapsed first.
count_transitions(problin_startle_full_tree_pars_resolved, cell_site_info_1, True)

input tree has 43282 nodes and 21107 leaves
removing 533
modified tree has 42749 nodes
modified tree has 20574 leaves


(Counter({(0, 4): 15,
          (0, 1): 23,
          (0, 3): 15,
          (1, 0): 43,
          (1, 2): 2,
          (0, 2): 65,
          (3, 1): 1,
          (3, 2): 2,
          (3, 0): 4,
          (2, 0): 5,
          (1, 4): 2,
          (3, 4): 1,
          (1, 3): 1}),
 179)

In [40]:
# Problin-startle full tree from the .pars_resolved.nwk file; all leaves kept.
count_transitions(problin_startle_full_tree_pars_resolved, cell_site_info_1, False)

input tree has 43282 nodes and 21107 leaves
modified tree has 21107 leaves


(Counter({(0, 4): 15,
          (0, 1): 23,
          (0, 3): 15,
          (1, 0): 43,
          (1, 2): 2,
          (0, 2): 65,
          (3, 1): 1,
          (3, 2): 2,
          (3, 0): 4,
          (2, 0): 5,
          (1, 4): 2,
          (3, 4): 1,
          (1, 3): 1}),
 179)

## Startle (Pruned to same 1460)

In [43]:
# Same call as the earlier linTracer_T_Fam_pruned / True cell (identical arguments and recorded output).
# NOTE(review): execution count jumps from 40 to 43 here — re-run the notebook top to bottom.
count_transitions(linTracer_T_Fam_pruned, cell_site_info_1, True)

input tree has 1871 nodes and 1460 leaves
removing 1013
modified tree has 858 nodes
modified tree has 447 leaves


(Counter({(0, 1): 9,
          (1, 0): 33,
          (0, 4): 8,
          (0, 3): 9,
          (0, 2): 17,
          (1, 2): 1,
          (3, 0): 5}),
 82)

In [44]:
# Same call as the earlier linTracer_T_Fam_pruned / False cell (identical arguments and recorded output).
count_transitions(linTracer_T_Fam_pruned, cell_site_info_1, False)

input tree has 1871 nodes and 1460 leaves
modified tree has 1460 leaves


(Counter({(0, 1): 9,
          (1, 0): 48,
          (0, 4): 9,
          (0, 3): 9,
          (0, 2): 35,
          (1, 2): 1,
          (3, 0): 8}),
 119)