### Goal

Analyze the output of pastml to identify:  
- convergent clusters: CNE state = 1 (present) only at tips (species). CNE state = 0 (absent) at all nodes of the tree.
- True conserved elements + potential convergence: CNE state = 1 at some node(s) but 0 at the common ancestor of all species that possess the CNE.
- True conserved elements: CNE state = 1 at some common ancestor of all species that possess the CNE.

### Input

- pastml_output_dict.txt : output of parsimony_analysis_part1.py
- phylogenetic tree represented as dictionary of parent-child relationships

### Output

- convergent_clusters.pickle : dictionary of convergent clusters
- cluster_lcas.pickle : dictionary of last ocmmon ancestor of each cluster

In [1]:
from pastml.acr import pastml_pipeline
import csv
import itertools
import pandas as pd
import sys
import pickle
import ast
from collections import Counter
import csv
from collections import defaultdict

#### Read dictionary

In [2]:
with open("pastml_output_dict.txt", "r") as data:
    dictionary = ast.literal_eval(data.read())

#### Pickle dictionary


In [3]:
with open('pastml_output_dict.pickle', 'wb') as handle:
    pickle.dump(dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Retrieve pickled dictionary

In [5]:
#with open('pastml_output_dict.pickle', 'rb') as handle:
#    dictionary = pickle.load(handle)

#### Species list

In [4]:
species_list = ['dgig',
                'ofav',
                'pdam',
                'spis',
                'adig', 
                'nvec',
                'epal',
                'aten', 
                'mvir',
                'aaur', 
                'chem',
                'hvul',
                'hsym'                                                                                       
               ]

#### Dictionary of parent-child relationships ("tree")

All species and nodes must be represented

In [5]:
tax_dict = {
    'aaur': 'acraspeda',
    'adig': 'scleractinia',
    'aten': 'enthemonae',
    'chem': 'leptothecata',
    'dgig': 'anthozoa',
    'epal': 'enthemonae',
    'hsym': 'leptothecata',
    'hvul': 'hydrozoa',
    'mvir': 'acraspeda',
    'nvec': 'actiniaria',
    'ofav': 'robusta',
    'pdam': 'pocilloporidae',
    'spis': 'pocilloporidae',
    'pocilloporidae': 'robusta',
    'robusta': 'scleractinia',
    'scleractinia': 'hexacorallia',
    'enthemonae': 'actiniaria',
    'actiniaria': 'hexacorallia',
    'hexacorallia': 'anthozoa',
    'anthozoa': 'cnidaria',
    'hydractinia': 'anthoathecata',
    'leptothecata': 'hydrozoa',
    'hydrozoa': 'medusozoa',
    'acraspeda': 'medusozoa',
    'medusozoa': 'cnidaria'
}

#### Parse character state dictionary

In [6]:
cluster_nodes = {}
for node_dict in dictionary['nodes']:
    for a, b in node_dict.items():
        #print(b)
        #if isinstance(b, type(dict)): #and 'node_name' in b:
        if 'node_name' in b:
            node_name = b['node_name']
            cluster_states = b['tooltip'].split("<br>")
            for cluster in cluster_states:
                cluster_id = cluster.split(":")[0]
                state = cluster.split(":")[1].replace(" ", "")
                if cluster_id in cluster_nodes:
                    cluster_nodes[cluster_id][node_name] = state
                else:
                    cluster_nodes[cluster_id] = {}
                    cluster_nodes[cluster_id][node_name] = state
#cluster_nodes

In [7]:
cluster_nodes['cluster_10']

{'aaur': '0',
 'acraspeda': '0',
 'actiniaria': '0',
 'adig': '1',
 'anthozoa': '0',
 'aten': '0',
 'chem': '0',
 'cnidaria': '0',
 'dgig': '0',
 'enthemonae': '0',
 'epal': '0',
 'hexacorallia': '0',
 'hsym': '0',
 'hvul': '0',
 'hydrozoa': '0',
 'leptothecata': '0',
 'medusozoa': '0',
 'mvir': '0',
 'nvec': '0',
 'ofav': '0',
 'pdam': '1',
 'pocilloporidae': '1',
 'robusta': '0or1',
 'scleractinia': '0or1',
 'spis': '1'}

#### Retrieve clusters where state is 0 for all nodes (only 1 at the tips)

In [8]:
convergent_clusters = [] # clusters where CNEs appear only at the tips
conserved_clusters = [] # clusters that appear at some deeper node
conserved_clust_nodes = {} # {cluster1:{species:[species_a, species_b, etc.], nodes:[node_x, node_y]}, cluster_2...}
conserved_clust_species = {}
for cluster, node_states in cluster_nodes.items():
    #print(cluster)
    node_count = 0
    for node, state in node_states.items():
        #print("node:",node, "state:", state)
        if node not in species_list and (state == '1' or state == '0or1'):
            node_count+=1
            conserved_clust_nodes.setdefault(cluster, []).append(node)   # [cluster][nodes].append(node)
        if node in species_list and state == '1':
            conserved_clust_species.setdefault(cluster, []).append(node) 
    if node_count == 0:
        convergent_clusters.append(cluster)
    else:
        conserved_clusters.append(cluster)

In [9]:
conserved_clust_nodes # nodes for which cluster state is 1 or '0 or 1'

{'cluster_1': ['cnidaria',
  'anthozoa',
  'medusozoa',
  'hexacorallia',
  'acraspeda',
  'hydrozoa',
  'scleractinia',
  'actiniaria',
  'leptothecata',
  'robusta',
  'enthemonae',
  'pocilloporidae'],
 'cluster_10': ['scleractinia', 'robusta', 'pocilloporidae'],
 'cluster_100': ['scleractinia', 'robusta', 'pocilloporidae'],
 'cluster_1000': ['scleractinia', 'robusta', 'pocilloporidae'],
 'cluster_10000': ['hexacorallia',
  'scleractinia',
  'robusta',
  'pocilloporidae'],
 'cluster_10001': ['hexacorallia', 'scleractinia', 'robusta'],
 'cluster_10002': ['hexacorallia', 'scleractinia', 'robusta'],
 'cluster_10003': ['hexacorallia',
  'scleractinia',
  'robusta',
  'pocilloporidae'],
 'cluster_10004': ['scleractinia', 'robusta', 'pocilloporidae'],
 'cluster_10005': ['hexacorallia', 'scleractinia', 'robusta'],
 'cluster_10006': ['hexacorallia', 'scleractinia', 'robusta', 'enthemonae'],
 'cluster_10007': ['scleractinia', 'robusta', 'pocilloporidae'],
 'cluster_10008': ['hexacorallia', '

In [10]:
conserved_clust_nodes['cluster_23']

['hexacorallia',
 'acraspeda',
 'scleractinia',
 'robusta',
 'enthemonae',
 'pocilloporidae']

In [11]:
conserved_clust_species # species for which state is 1

{'cluster_1': ['dgig',
  'mvir',
  'aaur',
  'hvul',
  'adig',
  'nvec',
  'chem',
  'ofav',
  'epal',
  'aten',
  'pdam',
  'spis'],
 'cluster_10': ['adig', 'pdam', 'spis'],
 'cluster_100': ['adig', 'ofav', 'pdam', 'spis'],
 'cluster_1000': ['adig', 'pdam', 'spis'],
 'cluster_10000': ['adig', 'ofav', 'spis'],
 'cluster_10001': ['adig', 'ofav'],
 'cluster_10002': ['adig', 'ofav'],
 'cluster_10003': ['adig', 'ofav', 'spis'],
 'cluster_10004': ['adig', 'ofav', 'pdam', 'spis'],
 'cluster_10005': ['adig', 'ofav'],
 'cluster_10006': ['adig', 'ofav', 'epal'],
 'cluster_10007': ['adig', 'ofav', 'pdam', 'spis'],
 'cluster_10008': ['adig', 'ofav'],
 'cluster_10009': ['adig', 'ofav'],
 'cluster_1001': ['adig', 'pdam', 'spis'],
 'cluster_10010': ['adig', 'ofav'],
 'cluster_10011': ['adig', 'ofav'],
 'cluster_10012': ['adig', 'ofav'],
 'cluster_10013': ['adig', 'ofav'],
 'cluster_10014': ['adig', 'ofav'],
 'cluster_10015': ['adig', 'ofav'],
 'cluster_10016': ['adig', 'ofav'],
 'cluster_10017': ['a

#### For each cluster:
- find common ancestor of all species
- check state at that node
- if node is 0, cluster is considered convergent

#### Function that retrieves all common ancestors of a set of species

In [12]:
def find_common_ancestors(sp_list, tree):
    result = None 
    all_ancestors = []
    for species in sp_list:
        ancestor = species
        ancestors = {species}
        while True:
            try:
                ancestor = tree[ancestor]  # get the species' ancestor
                ancestors.add(ancestor)  # store it in the ancestors set
            except KeyError:
                all_ancestors.append(ancestors) # add all ancestors of species to all_ancestors_list
                break
    # Now all_ancestors is a list of all sets of ancestors (one set per species)
    common = set.intersection(*all_ancestors)
    return common

#### Function that determines which common ancestor is the most recent

Run with ancestor set and original ancestor set as same set of ancestors. One set will change during recursion, the other will remain unchanged.  
To do: find a more elegant way to do this.

In [13]:
def most_recent_ancestor(ancestor_set, sp_list, tax_dict, original_ancestor_set):
    original_ancestor_set = original_ancestor_set.copy()
    for ancestor in ancestor_set.copy(): # copy to avoid changing length of set during iteration
        # create list of all parents of that ancestor in case none of them are in ancestor set
        parent_set = set()
        for parent, child in tax_dict.items():
            if child == ancestor: # Find parent of current ancestor
                parent_set.add(parent)
                if parent in ancestor_set and parent not in sp_list: # Jump to next parent in ancestor set
                    ancestor_set.discard(child)
                    if len(ancestor_set) == 1: # job done
                        return(parent)
                    else:
                        most_recent_ancestor(ancestor_set, sp_list, tax_dict, original_ancestor_set) # run function again with reduced ancestor set
                elif parent in sp_list:
                    return(child)
        if not parent_set & original_ancestor_set:
            return(ancestor)

#### Retrieve most recent common ancestor of all clusters
At the same time, identify clusters where common ancestor of all species does not have cluster (one of the species acquired convergently)


In [14]:
len(conserved_clust_species)

30189

In [15]:
cluster_lcas = {}
additional_convergent_clusters = []
for cluster, sp_list in conserved_clust_species.items():
    #if cluster in conserved_clusters:
        common_ancestors = find_common_ancestors(sp_list, tax_dict)
        lca = most_recent_ancestor(common_ancestors, sp_list, tax_dict, common_ancestors)
        if cluster in conserved_clust_nodes and lca not in conserved_clust_nodes[cluster]:
            additional_convergent_clusters.append(cluster)
        elif cluster not in convergent_clusters: 
            cluster_lcas[cluster] = lca

In [17]:
len(cluster_lcas)

18140

In [15]:
'cluster_23' in additional_convergent_clusters

True

#### Count losses

In [17]:
cluster_lcas

{'cluster_1': 'cnidaria',
 'cluster_10': 'scleractinia',
 'cluster_100': 'scleractinia',
 'cluster_1000': 'scleractinia',
 'cluster_10000': 'scleractinia',
 'cluster_10001': 'scleractinia',
 'cluster_10002': 'scleractinia',
 'cluster_10003': 'scleractinia',
 'cluster_10004': 'scleractinia',
 'cluster_10005': 'scleractinia',
 'cluster_10006': 'hexacorallia',
 'cluster_10007': 'scleractinia',
 'cluster_10008': 'scleractinia',
 'cluster_10009': 'scleractinia',
 'cluster_1001': 'scleractinia',
 'cluster_10010': 'scleractinia',
 'cluster_10011': 'scleractinia',
 'cluster_10012': 'scleractinia',
 'cluster_10013': 'scleractinia',
 'cluster_10014': 'scleractinia',
 'cluster_10015': 'scleractinia',
 'cluster_10016': 'scleractinia',
 'cluster_10017': 'scleractinia',
 'cluster_10018': 'scleractinia',
 'cluster_10019': 'scleractinia',
 'cluster_1002': 'scleractinia',
 'cluster_10020': 'scleractinia',
 'cluster_10021': 'scleractinia',
 'cluster_10022': 'scleractinia',
 'cluster_10023': 'scleractini

In [18]:
def daughter_species(node, daughter_list):
    #print("node: ", node)
    #print("daughter_list: ", daughter_list)
    for child, parent in tax_dict.items():
        if parent == node and child not in species_list:
            #print("found parent: ", parent, "child is: ", child)
            new_parent = child
            #print("run function again with ", new_parent, daughter_list)
            daughter_species(new_parent, daughter_list)
        if parent == node and child in species_list:
            #print("FOund parent, child is species: ", child)
            daughter_list.append(child)
            #print("new_daughter_list is:", daughter_list)
    #print("return daughter_list:", daughter_list)
    return(daughter_list)

In [19]:
cne_losses = {}
for cluster, lca in cluster_lcas.items():
    lca_daughters = daughter_species(lca, [])
    for species in lca_daughters:
        if species not in conserved_clust_species[cluster]:
            #print("cne lost in ", species)
            cne_losses.setdefault(cluster, []).append(species)
cne_losses

{'cluster_1': ['hsym'],
 'cluster_10': ['ofav'],
 'cluster_1000': ['ofav'],
 'cluster_10000': ['pdam'],
 'cluster_10001': ['pdam', 'spis'],
 'cluster_10002': ['pdam', 'spis'],
 'cluster_10003': ['pdam'],
 'cluster_10005': ['pdam', 'spis'],
 'cluster_10006': ['pdam', 'spis', 'nvec', 'aten'],
 'cluster_10008': ['pdam', 'spis'],
 'cluster_10009': ['pdam', 'spis'],
 'cluster_1001': ['ofav'],
 'cluster_10010': ['pdam', 'spis'],
 'cluster_10011': ['pdam', 'spis'],
 'cluster_10012': ['pdam', 'spis'],
 'cluster_10013': ['pdam', 'spis'],
 'cluster_10014': ['pdam', 'spis'],
 'cluster_10015': ['pdam', 'spis'],
 'cluster_10016': ['pdam', 'spis'],
 'cluster_10017': ['pdam', 'spis'],
 'cluster_10018': ['pdam', 'spis'],
 'cluster_10019': ['pdam', 'spis'],
 'cluster_10020': ['pdam'],
 'cluster_10021': ['pdam', 'spis'],
 'cluster_10022': ['pdam', 'spis'],
 'cluster_10023': ['pdam', 'spis'],
 'cluster_10024': ['pdam', 'spis'],
 'cluster_10025': ['pdam', 'spis'],
 'cluster_10026': ['pdam', 'spis'],
 'clu

In [20]:
lca_losses = defaultdict(int)
for cluster, species in cne_losses.items():
    lca = cluster_lcas[cluster]
    lca_losses[lca] += len(species)

In [21]:
lca_losses

defaultdict(int,
            {'actiniaria': 402,
             'anthozoa': 3501,
             'cnidaria': 3308,
             'hexacorallia': 2629,
             'hydrozoa': 449,
             'medusozoa': 157,
             'robusta': 3040,
             'scleractinia': 13225})

#### Common ancestor distribution

In [22]:
counts = Counter(cluster_lcas.values())
for node in set(cluster_lcas.values()):
    print(node, counts[node])

hydrozoa 469
scleractinia 9516
robusta 5484
anthozoa 733
medusozoa 81
hexacorallia 918
actiniaria 474
cnidaria 465


In [23]:
lca_counts_df = pd.DataFrame(counts.items(), columns=['node', 'cluster_count']).sort_values(by='node')
lca_counts_df.to_csv("lca_counts.tsv", sep="\t", index=False)

In [24]:
lca_counts_df

Unnamed: 0,node,cluster_count
4,actiniaria,474
3,anthozoa,733
0,cnidaria,465
2,hexacorallia,918
6,hydrozoa,469
5,medusozoa,81
7,robusta,5484
1,scleractinia,9516


#### Number of convergent clusters

In [22]:
# Total number of clusters (excluding single CE clusters)
#len(pastml_data.columns)-1

In [25]:
len(cluster_nodes)

30189

In [26]:
len(cluster_lcas) # non-convergent clusters

18140

In [27]:
len(convergent_clusters) # 1 only at the tips

11419

In [28]:
len(additional_convergent_clusters) # 1 at some node but CNE absent at last common ancestor

630

In [27]:
# sum chould be equal to total number of clusters
11 + 66 + 5 # all is well

82

#### Write output to files

In [29]:
with open('convergent_clusters.pickle', 'wb') as handle:
    pickle.dump(convergent_clusters, handle)

In [30]:
cluster_lcas

{'cluster_1': 'cnidaria',
 'cluster_10': 'scleractinia',
 'cluster_100': 'scleractinia',
 'cluster_1000': 'scleractinia',
 'cluster_10000': 'scleractinia',
 'cluster_10001': 'scleractinia',
 'cluster_10002': 'scleractinia',
 'cluster_10003': 'scleractinia',
 'cluster_10004': 'scleractinia',
 'cluster_10005': 'scleractinia',
 'cluster_10006': 'hexacorallia',
 'cluster_10007': 'scleractinia',
 'cluster_10008': 'scleractinia',
 'cluster_10009': 'scleractinia',
 'cluster_1001': 'scleractinia',
 'cluster_10010': 'scleractinia',
 'cluster_10011': 'scleractinia',
 'cluster_10012': 'scleractinia',
 'cluster_10013': 'scleractinia',
 'cluster_10014': 'scleractinia',
 'cluster_10015': 'scleractinia',
 'cluster_10016': 'scleractinia',
 'cluster_10017': 'scleractinia',
 'cluster_10018': 'scleractinia',
 'cluster_10019': 'scleractinia',
 'cluster_1002': 'scleractinia',
 'cluster_10020': 'scleractinia',
 'cluster_10021': 'scleractinia',
 'cluster_10022': 'scleractinia',
 'cluster_10023': 'scleractini

In [31]:
with open('cluster_lcas.pickle', 'wb') as handle:
    pickle.dump(cluster_lcas, handle)