# Imports

In [6]:
import numpy as np
import pandas as pd
import networkx as nx
import ete3

# Helper functions for simulation

In [7]:
def simulate_imbalanced_tree(num_init_cells, init_death_prob=0.1, init_repr_prob=0.75, cancer_prob=1e-3, tree_depth=15,
                             cancer_effect=1e-2):
    """
    Simulate a branching process of cells and return the binary-coded cell names of every generation.

    Founder cells are named '0', '1', ..., str(num_init_cells - 1). In each generation every cell
    independently dies with its own death probability. A surviving cell divides into two children with
    its reproduction probability; otherwise it passes through as a single child. A child's name is its
    parent's name plus '0' (first child) or '1' (second child of a division), so a cell's parent is
    always name[:-1].

    After each generation, every cell acquires a "cancerous" mutation with probability cancer_prob.
    The mutation lowers the cell's death probability and raises its reproduction probability by
    cancer_effect. These changes are inherited by all descendants.

    NOTE: founder names are str(i), so with more than 10 founders a founder such as '10' collides with
    the first child of founder '1'. Keep num_init_cells <= 10.

    @param: num_init_cells (int) - Number of founder cells
    @param: init_death_prob (float) - Initial per-generation death probability of every cell
    @param: init_repr_prob (float) - Initial per-generation division probability of every surviving cell
    @param: cancer_prob (float) - Per-generation probability that a cell acquires a fitness-increasing mutation
    @param: tree_depth (int) - Number of generations to simulate
    @param: cancer_effect (float) - Amount subtracted from the death probability and added to the
            reproduction probability of a mutated cell
    @return: list of np.ndarray of str with tree_depth + 1 entries; entry 0 holds the founder names
    @raises: RuntimeError if every cell has died before tree_depth generations are simulated
    """
    # Always keep the probabilities as floats: integer inputs (e.g. 0 or 1) would otherwise produce
    # int arrays, and the in-place float update for cancerous cells below would fail on them.
    death_probs = np.full(num_init_cells, init_death_prob, dtype=float)
    repr_probs = np.full(num_init_cells, init_repr_prob, dtype=float)

    names = np.array([str(i) for i in range(num_init_cells)])
    cell_names = [names]

    for _ in range(tree_depth):
        num_cells = len(names)
        dying = np.random.random(num_cells) < death_probs
        dividing = np.random.random(num_cells) < repr_probs

        # Children per cell: 0 if it dies, 2 if it divides, 1 if it just persists.
        num_children = (dividing + 1) * (1 - dying)
        next_gen = np.repeat(names, num_children)

        if len(next_gen) == 0:
            raise RuntimeError('No cells left to replicate. Terminate simulation.')

        # np.repeat keeps siblings adjacent, so a name equal to its predecessor is the second child.
        names = np.array([
            parent + ('1' if k > 0 and next_gen[k - 1] == parent else '0')
            for k, parent in enumerate(next_gen)
        ])
        cell_names.append(names)

        # Children inherit their parent's (possibly mutated) probabilities.
        death_probs = np.repeat(death_probs, num_children)
        repr_probs = np.repeat(repr_probs, num_children)

        # Introduce cancerous mutations which may increase tumour fitness.
        has_cancer = np.random.random(len(names)) < cancer_prob
        death_probs[has_cancer] -= cancer_effect
        repr_probs[has_cancer] += cancer_effect

    return cell_names


In [8]:
def generate_cassettes(num_cells, cassette_size, mutation_probs=None, deletion_probs=None, unedited_fraction=0.8):
    """
    Simulate one CRISPR recording cassette per cell (placeholder implementation).

    CRISPR edits are independent of each other. Each edit site has some probability of being modified
    or being deleted. The only dependence between cells is that a site which has already been edited or
    deleted cannot be edited or deleted again.

    We can therefore simulate CRISPR edits a priori for a given number of cells. These edits are then
    attached to cells in the lineage independently of one another, as vector operations for speed.

    This placeholder draws every entry uniformly from the integers -1..99. It then forces a random
    unedited_fraction of all entries to 0 (presumably 0 = unedited and -1 = deletion; confirm once the
    real model is written).

    @param: num_cells (int) - Number of cells for which we should simulate CRISPR recording cassettes
    @param: cassette_size (int) - Number of sites where CRISPR edits can occur
    @param: mutation_probs - Currently ignored by this placeholder
    @param: deletion_probs - Currently ignored by this placeholder
    @param: unedited_fraction (float) - Fraction of all entries forced to 0
    @return: list of num_cells lists, each with cassette_size ints
    """

    print('THIS IS A PLACEHOLDER FOR THE REAL CASSETTE FUNCTION. CHANGE THIS!')

    # Size by the num_cells argument (previously this read the notebook-global total_internal_nodes).
    cassette_edits = np.random.randint(low=-1, high=100, size=(num_cells, cassette_size))

    indices = np.random.choice(np.arange(cassette_edits.size), replace=False,
                               size=int(cassette_edits.size * unedited_fraction))
    cassette_edits[np.unravel_index(indices, cassette_edits.shape)] = 0
    return cassette_edits.tolist()


In [9]:
def graph_to_ete3(G, root="ROOT"):
    """
    Convert a networkx DiGraph lineage into an ete3 tree.

    Every graph node becomes an ete3.Tree named after the node. Every directed edge (parent, child)
    attaches the child's subtree under the parent's subtree, in G.edges() order.

    @param: G (nx.DiGraph) - Lineage graph; assumed to be a tree hanging off `root` (TODO confirm no
            caller passes a forest or a node with two parents)
    @param: root - Name of the node returned as the tree root; defaults to "ROOT", the node that
            generate_lineage attaches founder cells to
    @return: ete3.Tree rooted at `root`
    @raises: KeyError if `root` is not a node of G
    """
    subtrees = {node: ete3.Tree(name=node) for node in G.nodes()}
    # Plain loop: add_child is called for its side effect, so no throwaway list is built.
    for parent, child in G.edges():
        subtrees[parent].add_child(subtrees[child])
    return subtrees[root]


In [14]:
def generate_lineage(cell_names, cassette_edits):
    """
    Build a networkx lineage graph with accumulated CRISPR edits.

    Input is the binary names of cells per generation (e.g. as returned by simulate_imbalanced_tree)
    plus one row of CRISPR edits per cell. The function constructs a DiGraph with an edge from each
    parent to each child. Founder cells (generation 0) hang off a synthetic "ROOT" node. Every other
    cell's parent is its name minus the last character.

    Each cell node carries:
      - generation (int): index of the cell's generation in cell_names
      - crispr_edit (np.ndarray): the edit row drawn for this cell, with sites already edited in the
        parent's cassette set to 0 (sites cannot be edited twice)
      - cassette_state (np.ndarray): the parent's cassette_state plus crispr_edit (founders start
        from a blank cassette)
    The "ROOT" node carries no attributes.

    @param: cell_names (iterable of iterables of str) - Cell names per generation, parents before children
    @param: cassette_edits (sequence of equal-length int rows) - At least one row per cell. Rows are
            consumed from the end (the last row goes to the first cell). The caller's sequence is not modified.
    @return: nx.DiGraph
    @raises: IndexError if there are fewer edit rows than cells
    @raises: KeyError if a non-founder cell's parent is not in an earlier generation
    """
    generations = [list(cells) for cells in cell_names]
    total_cells = sum(len(cells) for cells in generations)
    if len(cassette_edits) < total_cells:
        raise IndexError(
            f'Need {total_cells} cassette edit rows but only {len(cassette_edits)} were given.'
        )
    # Read rows from the end, matching the assignment order of the former pop()-based version,
    # without emptying the caller's list (this also works for 2-D numpy arrays).
    edit_rows = reversed(cassette_edits)

    lineage = nx.DiGraph()
    for generation, cells in enumerate(generations):
        for cell in cells:
            crispr_edit = np.array(next(edit_rows))
            if generation == 0:
                # Founders are identified by position, not by name length, so names like '10' work.
                parent = "ROOT"
                cassette_state = crispr_edit.copy()
            else:
                parent = cell[:-1]
                parent_state = lineage.nodes[parent]['cassette_state']
                # Sites which are already edited in the lineage are forbidden to be edited again.
                crispr_edit[parent_state != 0] = 0
                cassette_state = crispr_edit + parent_state

            lineage.add_node(cell, generation=generation, cassette_state=cassette_state, crispr_edit=crispr_edit)
            # Add an edge between the parent node and the newly added child.
            lineage.add_edge(parent, cell)

    return lineage

# Run lineage tracing simulation

In [15]:
# Simulate a small imbalanced lineage: two founder cells grown for five generations.
cell_names = simulate_imbalanced_tree(
    num_init_cells=2,
    init_death_prob=0.1,
    init_repr_prob=0.75,
    cancer_prob=1e-3,
    tree_depth=5,
)

# One cassette row per simulated cell across all generations. Leaves are counted too; the variable
# name is historical and must stay, because the placeholder cassette cell reads it.
total_internal_nodes = sum(map(len, cell_names))
cassette_edits = generate_cassettes(total_internal_nodes, cassette_size=5)

# Assemble the networkx lineage, convert it to ete3, and draw it.
lineage = generate_lineage(cell_names, cassette_edits)
tree = graph_to_ete3(lineage)

print(tree.get_ascii())
    

THIS IS A PLACEHOLD FOR THE REAL CASSETTE FUNCTION. CHANGE THIS!

                 /-00000
             /0000
            |   |     /-000010
            |    \00001
          /000        \-000011
         |  |
         |  |         /-000100
         |   \000100010
         |            \-000101
         |
       /00            /-001000
      |  |       /00100
      |  |      |     \-001001
      |  |   /0010
      |  |  |   |     /-001010
      |  |  |    \00101
      |   \001        \-001011
      |     |
      |     |         /-001100
    /0|     |    /00110
   |  |      \0011    \-001101
   |  |         |
   |  |          \-00111
   |  |
   |  |               /-010000
   |  |          /01000
   |  |   /0100100    \-010001
   |  |  |      |
   |  |  |       \01001-010010
   |   \01
   |     |            /-011000
   |     |   /011001100
   |     |  |         \-011001
   |      \011
-ROOT       |         /-011100
   |        |    /01110
   |         \0111    \-011101
   |            |


# Manipulating lineage graphs

In [None]:
def get_character_matrix(lineage):
    """
    Build the character matrix of `lineage` as a pd.DataFrame.

    Intended to map the cell IDs of leaf nodes to their CRISPR edits. Not implemented yet.
    """
    # TODO: implement
    raise NotImplementedError()
    
def drop_missing_data(lineage, missing_fraction):
    """
    Return a copy of `lineage` with missing data introduced.

    Missing entries are meant to appear as '-' in the character matrix, with `missing_fraction`
    presumably controlling how much data is dropped. Not implemented yet.
    """
    # TODO: implement
    raise NotImplementedError()
    
def save(lineage):
    """
    Save a copy of the lineage networkx DiGraph in some easily read form.

    Not implemented yet.
    """
    # TODO: implement
    raise NotImplementedError()
    
    

# Benchmark against Cassiopeia

Cassiopeia should be able to read both networkx and ete3 trees. Generate some imbalanced trees and quantify how imbalanced they are (e.g. via the distribution of leaf depths, or a similar statistic). Then record the reconstruction performance of neighbor joining (NJ) vs. Cassiopeia vs. Scelestial.

Also evaluate each method as a function of the fraction of missing data.