In [29]:
import pandas as pd
import sys
import argparse
import itertools
import math
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import cassiopeia as cass
import os
import json
import networkx as nx
import itertools
import ete3

import itertools
from collections import Counter
from collections import defaultdict
from collections import deque
from Bio import Phylo

import subprocess

from utilities import *
%load_ext autoreload
%autoreload 2

#import gurobipy as gp

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Helper Functions

In [19]:
def get_smaller_tree(tree, cell_site_info):

    mod_tree = tree.copy()
    leaves_to_remove = []
    for node in mod_tree.nodes:
        leaf_site_list = []
        
        for child in mod_tree[node]:
            if is_leaf(mod_tree, child):
                leaf_site = cell_site_info[child]
                if leaf_site not in leaf_site_list:
                    leaf_site_list.append(leaf_site)
                else:
                    leaves_to_remove.append(child)
    
    for leaf in leaves_to_remove:
        mod_tree.remove_node(leaf)
    print(f"removing {len(leaves_to_remove)}")
    return mod_tree

In [20]:
def update_parsimonious_labeling_counts(T, cell_site_info, state_list, cost_labeling, count_labeling, node):
    if is_leaf(T, node):
        for state_idx, state in enumerate(state_list):
            if state == cell_site_info[node]:
                cost_labeling[node][state_idx] = 0
                count_labeling[node][state_idx] = 1
            else:
                cost_labeling[node][state_idx] = np.inf
                count_labeling[node][state_idx] = 0
    else:
        for child in T[node]:
            update_parsimonious_labeling_counts(T, cell_site_info, state_list, cost_labeling, count_labeling, child)
        
        for parent_state_idx, state in enumerate(state_list):
            cost_labeling[node][parent_state_idx] = 0
            count_labeling[node][parent_state_idx] = 1
            for child in T[node]:
                min_cost = np.inf
                for child_state_idx, child_state in enumerate(state_list):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[child][child_state_idx]
                    else:
                        curr_cost = cost_labeling[child][child_state_idx] + 1

                    if curr_cost < min_cost:
                        min_cost = curr_cost
                        min_child_state_idx = child_state_idx

                cost_labeling[node][parent_state_idx] += min_cost

                curr_child_total_count = 0
                for child_state_idx, child_state in enumerate(state_list):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[child][child_state_idx]
                    else:
                        curr_cost = cost_labeling[child][child_state_idx] + 1

                    if curr_cost == min_cost:
                        curr_child_total_count += count_labeling[child][child_state_idx]

                count_labeling[node][parent_state_idx] *= curr_child_total_count


In [21]:
def count_parsimonious_labelings(T, cell_site_info):
    state_list = list(cell_site_info.unique())
    nstates = len(state_list)
    # state_transition_counter = Counter()
    node_cost_labeling = {node:[None]*nstates for node in T.nodes}
    # print(T.nodes)
    # print(node_cost_labeling['root'])
    node_count_labeling = {node:[None]*nstates for node in T.nodes}
    update_parsimonious_labeling_counts(T, cell_site_info, state_list, node_cost_labeling, node_count_labeling, 'root')
    
    return node_cost_labeling, node_count_labeling

In [22]:
def computeG(T, cell_site_info, state_list, cost_labeling, presence_labeling, node):
    min_cost = min(cost_labeling[node])
    if node == 'root':
         for state_idx in range(len(state_list)):
                if cost_labeling[node][state_idx] == min_cost:
                    presence_labeling[node][state_idx] = 1
                else:
                    presence_labeling[node][state_idx] = 0
    else:
        
        parent_node = list(T.pred[node])[0]
        
        for state_idx in range(len(state_list)):
            presence_labeling[node][state_idx] = 0
        
        for parent_state_idx in range(len(state_list)):
            if presence_labeling[parent_node][parent_state_idx] == 1:

                min_cost = np.inf
                for child_state_idx, child_state in enumerate(state_list):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[node][child_state_idx]
                    else:
                        curr_cost = cost_labeling[node][child_state_idx] + 1

                    if curr_cost < min_cost:
                        min_cost = curr_cost
        
                for child_state_idx in range(len(state_list)):
                    if child_state_idx == parent_state_idx:
                        curr_cost = cost_labeling[node][child_state_idx]
                    else:
                        curr_cost = cost_labeling[node][child_state_idx] + 1
                        
                    if curr_cost == min_cost:
                        presence_labeling[node][child_state_idx] = 1
                

    for child in T[node]:
        computeG(T, cell_site_info, state_list, cost_labeling, presence_labeling, child)

# Read in Inputs

In [82]:
df_kp_meta = pd.read_csv(f'KPTracer_meta.csv', index_col = 0)

In [83]:
df_character_matrix_Fam_1 = pd.read_csv('3724_NT_All_character_matrix.txt',
                                        index_col = 0, sep='\t', dtype=str)
df_character_matrix_Fam_1 = df_character_matrix_Fam_1.replace('-', '-1')

In [84]:
cell_site_info_1 = df_kp_meta.loc[df_character_matrix_Fam_1.index]['SubTumor']
for cell, site in cell_site_info_1.items():
    # if site.startswith('3513_NT_T'):
        # cell_site_info[cell] = 'primary'
    if len(site.split('_')) > 3:
        cell_site_info_1[cell] = '_'.join(site.split('_')[:-1])

In [36]:
# the published cassiopeia tree, 21,108 nodes
cassh_21108 = from_newick_get_nx_tree('../published/casshybrid.nwk')

# the published startle tree, 21,108 nodes
startle_21108 = from_newick_get_nx_tree('../published/startle.nwk')

# the full problin tree, 21,108 nodes
problin_21108 = from_newick_get_nx_tree('../pruned_resolved/problin_pruned_resolved_21108.nwk')

In [52]:
# the deduplicated cassiopeia tree, 1461 nodes
cassh_1461 = from_newick_get_nx_tree('../pruned_1461/casshybrid_1461.nwk')

# the deduplicated startle tree, 1461 nodes
startle_1461 = from_newick_get_nx_tree('../pruned_1461/startle_1461.nwk')

# the deduplicated problin tree, 1461 nodes
problin_1461 = from_newick_get_nx_tree('../pruned_1461/problin_1461.nwk')

In [54]:
# the cassiopeia tree with all identical sequences placed as polytomies
cassh_21108_replaced = from_newick_get_nx_tree('../identical_polytomies/casshybrid_pruned_resolved_21108.nwk')

# the startle tree with all identical sequences placed as polytomies
startle_21108_replaced = from_newick_get_nx_tree('../identical_polytomies/startle_pruned_resolved_21108.nwk')


# Set up Functions

In [44]:

def count_transitions(input_tree, cell_site_info, throw_out_identical_sites=False):
    num_leaves = 0
    for n in input_tree.nodes:
        if is_leaf(input_tree, n): 
            num_leaves += 1
    
    print(f"input tree has {len(input_tree.nodes)} nodes and {num_leaves} leaves")
    if throw_out_identical_sites:
        input_tree = get_smaller_tree(input_tree, cell_site_info_1)
        print(f"modified tree has {len(input_tree.nodes)} nodes") 
    
    num_leaves = 0
    for n in input_tree.nodes:
        if is_leaf(input_tree, n): 
            num_leaves += 1
    print(f"modified tree has {num_leaves} leaves")
    cost_labeling, count_labeling = count_parsimonious_labelings(input_tree, cell_site_info)
    state_list = list(cell_site_info.unique())
    nstates = len(state_list)
    presence_labeling = {node:[None]*nstates for node in input_tree.nodes}
    computeG(input_tree, cell_site_info, state_list, cost_labeling, presence_labeling, 'root')

    optimal_labeling = {node:'' for node in input_tree.nodes}

    for node in input_tree.nodes:
        if presence_labeling[node][0] == 1:
            optimal_labeling[node] = 0
        else:
            optimal_labeling[node] = np.where(presence_labeling[node])[0][0]

    num_transitions = Counter()
    for edge in input_tree.edges:
        if optimal_labeling[edge[0]] != optimal_labeling[edge[1]]:
            num_transitions[(optimal_labeling[edge[0]], optimal_labeling[edge[1]])] += 1

    return num_transitions, sum(num_transitions.values())


# Full Trees Polytomies

In [45]:
# migration cost, cassiopeia

count_transitions(cassh_21108, cell_site_info_1)

input tree has 22748 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 19,
          (0, 4): 71,
          (0, 3): 13,
          (1, 0): 175,
          (0, 2): 1261,
          (1, 4): 2,
          (3, 0): 28,
          (1, 2): 2,
          (1, 3): 1,
          (2, 0): 15,
          (2, 4): 1,
          (3, 2): 6,
          (3, 4): 1,
          (4, 0): 4,
          (3, 1): 1}),
 1600)

In [46]:
# migration cost, startle

count_transitions(startle_21108, cell_site_info_1)

input tree has 21816 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 16,
          (1, 0): 356,
          (0, 4): 84,
          (0, 3): 15,
          (0, 2): 1260,
          (1, 2): 2,
          (1, 3): 1,
          (1, 4): 2,
          (3, 0): 24,
          (2, 0): 14,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 4): 1}),
 1783)

In [48]:
# migration cost, problin

count_transitions(problin_21108, cell_site_info_1)

input tree has 24029 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 4): 72,
          (0, 1): 30,
          (0, 3): 17,
          (1, 0): 150,
          (1, 2): 2,
          (0, 2): 1244,
          (2, 0): 18,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 0): 20,
          (1, 4): 2,
          (3, 4): 1,
          (4, 0): 3,
          (1, 3): 1}),
 1568)

# Deduplicated to 1461 Leaves

In [49]:
# migration cost, cassiopeia

count_transitions(cassh_1461, cell_site_info_1)

input tree has 2800 nodes and 1461 leaves
modified tree has 1461 leaves


(Counter({(0, 1): 15,
          (0, 4): 9,
          (0, 3): 7,
          (1, 0): 57,
          (3, 0): 11,
          (0, 2): 36,
          (1, 2): 1}),
 136)

In [50]:
# migration cost, startle

count_transitions(startle_1461, cell_site_info_1, False)

input tree has 1872 nodes and 1461 leaves
modified tree has 1461 leaves


(Counter({(0, 1): 9,
          (1, 0): 48,
          (0, 4): 9,
          (0, 3): 9,
          (0, 2): 35,
          (1, 2): 1,
          (3, 0): 8}),
 119)

In [51]:
# migration cost, problin


count_transitions(problin_1461, cell_site_info_1, False)

input tree has 2921 nodes and 1461 leaves
modified tree has 1461 leaves


(Counter({(0, 4): 9,
          (0, 1): 15,
          (0, 3): 11,
          (1, 0): 37,
          (0, 2): 19,
          (3, 0): 4,
          (2, 0): 3,
          (1, 2): 1}),
 99)

# Deduplicated then replace identical sequences as polytomies

In [55]:
# migration cost, cassiopeia

count_transitions(cassh_21108_replaced, cell_site_info_1)

input tree has 23908 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 21,
          (0, 4): 71,
          (0, 3): 13,
          (1, 0): 173,
          (3, 0): 28,
          (0, 2): 1261,
          (1, 4): 2,
          (1, 3): 1,
          (1, 2): 2,
          (2, 4): 1,
          (2, 0): 14,
          (4, 0): 4,
          (3, 4): 1,
          (3, 2): 6,
          (3, 1): 1}),
 1599)

In [85]:
# migration cost, startle

count_transitions(startle_21108_replaced, cell_site_info_1)

input tree has 22980 nodes and 21108 leaves
modified tree has 21108 leaves


(Counter({(0, 1): 21,
          (1, 0): 162,
          (0, 4): 72,
          (0, 3): 15,
          (0, 2): 1260,
          (1, 2): 2,
          (1, 3): 1,
          (1, 4): 2,
          (3, 0): 24,
          (2, 0): 14,
          (2, 4): 1,
          (3, 1): 1,
          (3, 2): 6,
          (3, 4): 1,
          (4, 0): 3}),
 1585)