In [1]:
import ete3
import os
from pymongo import MongoClient
from bson.objectid import ObjectId
from pymongo.errors import DuplicateKeyError
from pprint import pprint
import json

In [2]:
TREES_PATH = '../data/1KP/species_level/trees/'
INPUT_GROUP_ID = 1
OUTPUT_FILENAME = 'input_group_%d.json' % INPUT_GROUP_ID
TITLE = '69 Species Trees'

In [3]:
# Open MongoDB connection
mongo_client = MongoClient('localhost', 27017)
mongo_db = mongo_client.visphy_dev
tree_col = mongo_db.tree
branch_col = mongo_db.branch
input_group_col = mongo_db.inputGroup
entity_col = mongo_db.entity

In [4]:
# Add input group to MongoDB
input_group = input_group_col.find_one_and_update({'inputGroupId': INPUT_GROUP_ID}, 
                                            {'$set': {'inputGroupId': INPUT_GROUP_ID, 'title': TITLE, 'trees': []}})
print input_group
input_group_id = input_group['_id']
print 'inputGroup oid: ', input_group_id

{u'inputGroupId': 1.0, u'_id': ObjectId('5837a01ac637b3313f44e8fb'), u'trees': [ObjectId('5837a1af7443ee13c12d70d9'), ObjectId('5837a1af7443ee13c12d70d8'), ObjectId('5837a1af7443ee13c12d70d5'), ObjectId('5837a1af7443ee13c12d70d4'), ObjectId('5837a1af7443ee13c12d70d7'), ObjectId('5837a1af7443ee13c12d70d6'), ObjectId('5837a1af7443ee13c12d70d1'), ObjectId('5837a1af7443ee13c12d70d0'), ObjectId('5837a1af7443ee13c12d70d3'), ObjectId('5837a1af7443ee13c12d70d2'), ObjectId('5837a1af7443ee13c12d70eb'), ObjectId('5837a1af7443ee13c12d70ec'), ObjectId('5837a1af7443ee13c12d70ea'), ObjectId('5837a1af7443ee13c12d70ef'), ObjectId('5837a1af7443ee13c12d70ed'), ObjectId('5837a1af7443ee13c12d70ee'), ObjectId('5837a1af7443ee13c12d70ca'), ObjectId('5837a1af7443ee13c12d70cb'), ObjectId('5837a1af7443ee13c12d70cc'), ObjectId('5837a1af7443ee13c12d70cd'), ObjectId('5837a1af7443ee13c12d70ce'), ObjectId('5837a1af7443ee13c12d70cf'), ObjectId('5837a1af7443ee13c12d70f7'), ObjectId('5837a1af7443ee13c12d70f6'), ObjectId

In [5]:
# Open tree files
trees = {}
for dirpath, dirnames, filenames in os.walk(TREES_PATH):
    for filename in filenames:
        if filename.endswith('.tre'):
            with open(os.path.join(dirpath, filename)) as f:
                new_tree = {
                    'inputGroupId': INPUT_GROUP_ID,
                    'newickString': f.readline(), 
                    'name': filename[:-4], 
                    'type': 'species',
                    'rfDistance': {}
                }
                insert_res = tree_col.insert_one(new_tree)
                trees[str(insert_res.inserted_id)] = new_tree

In [6]:
# Put all tree ids to input group document in Mongo
input_group_col.find_one_and_update({'_id': input_group_id}, {'$set': {'trees': [ObjectId(k) for k in trees]}}) 

{u'_id': ObjectId('5837a01ac637b3313f44e8fb'),
 u'inputGroupId': 1.0,
 u'title': u'69 Species Trees',
 u'trees': []}

In [7]:
# Parse the trees
entities = {}
ent_names = {}
for tid, tval in trees.iteritems():
    root = ete3.Tree(tval['newickString'], format=2)
    # Make sure the tree is strictly bifurcated
    root.resolve_polytomy(recursive=True)
    tval['ete_tree'] = root
    # Process the leaves
    for leaf in root:
        if leaf.name not in ent_names:
            ent_id = None
            try:
                ent_id = entity_col.insert_one({'name': leaf.name, 'type': 'species', 'inputGroupId': INPUT_GROUP_ID}).inserted_id
            except DuplicateKeyError:
                ent_id = entity_col.find_one({'name': leaf.name}).get('_id')
            ent_names[leaf.name] = ent_id
        leaf.add_feature('entity_id', ent_names[leaf.name])
            
    # Process the branches
    branches = []
    root_branch_id = branch_col.insert_one({'length': 0, 'support': 0, 'tree': ObjectId(tid), 'entities': [x.entity_id for x in root],
                                            'inputGroupId': INPUT_GROUP_ID}).inserted_id
    root.add_feature('branch_id', root_branch_id)
    
    for node in root.iter_descendants():
        branch_id = branch_col.insert_one({'length': node.dist, 'support': node.support, 
                                           'parent': node.up.branch_id, 'tree': ObjectId(tid),
                                          'isLeaf': node.is_leaf(), 'correspondingBranches': {},
                                           'inputGroupId': INPUT_GROUP_ID,
                                          'entities': [x.entity_id for x in node]}).inserted_id
        node.add_feature('branch_id', branch_id)
    for node in root.traverse():
        if not node.is_leaf():
            branch_col.find_one_and_update({'_id': node.branch_id}, 
                                           {'$set': {'left': node.children[0].branch_id, 
                                                     'right': node.children[1].branch_id}})
            

In [8]:
for tid, tval in trees.iteritems():
    tree_col.find_one_and_update({'_id': ObjectId(tid)}, {'$set': {'rootBranch': tval['ete_tree'].branch_id}})

In [9]:
# Calculate RF distance between any two trees
for tid1, tval1 in trees.iteritems():
    for tid2, tval2 in trees.iteritems():
        if tid1 < tid2:
            res = tval1['ete_tree'].robinson_foulds(tval2['ete_tree'])
            #print tval1
            tval1['rfDistance'][tid2] = res[0]
            tval2['rfDistance'][tid1] = res[0]
    # break



In [10]:
# Update rfDistance in MongoDB
for tid, tval in trees.iteritems():
    tree_col.find_one_and_update({'_id': ObjectId(tid)}, {'$set': {'rfDistance': tval['rfDistance']}})

In [11]:
for tid, tval in trees.iteritems():
    for n in tval['ete_tree'].traverse():
        n.add_feature('entity_set', set([x.entity_id for x in n]))

In [12]:
# Find corresponding nodes for each node in each tree

# Calc Jaccard distance between two sets of entities that defined by internal node x and y
def get_jaccard(x, y):
    a = x.entity_set
    b = y.entity_set
    c = a.intersection(b)
    return float(len(c)) / float((len(a) + len(b) - len(c)))

# Find corresponding node of "node" from t
# Return the corresponding node, the jaccard index
def find_corr_node(node, t):
    max_jaccard = 0
    num_entities_diff = 0
    corr_node = 0
    epsilon = 0.001
    for x in t.traverse('levelorder'):
        d = get_jaccard(node, x)
        if d > max_jaccard or (abs(d - max_jaccard) < epsilon and abs(len(node) - len(x)) < num_entities_diff):
            max_jaccard = d
            corr_node = x
            num_entities_diff = abs(len(node) - len(x))
            if abs(d - 1.0) < epsilon:
                break
    return corr_node, max_jaccard


        

In [13]:
# For every pair of tree, iterate over every node in trees
# TODO TOO SLOW!!!!!
for tid1, tval1 in trees.iteritems():
    t1 = tval1['ete_tree']
    for n in t1.traverse('levelorder'):
        corr = {}
        for tid2, tval2 in trees.iteritems():
            if tid1 != tid2:
                t2 = tval2['ete_tree']
           
                c, max_jaccard = find_corr_node(n, t2)
                # print 'corr found', c, max_jaccard
                corr[tid2] = {
                    'branchId': c.branch_id,
                    'jaccard': max_jaccard
                }
        # n.add_feature('corr_branches', corr)
        branch_col.find_one_and_update({'_id': n.branch_id}, {'$set': {'correspondingBranches': corr}})
    print tid1
                                
            

5837a6cf7443ee13f7d126f9
5837a6cf7443ee13f7d126f8
5837a6cf7443ee13f7d1270c
5837a6cf7443ee13f7d126f1
5837a6cf7443ee13f7d126f0
5837a6cf7443ee13f7d126f3
5837a6cf7443ee13f7d126f2
5837a6cf7443ee13f7d126f5
5837a6cf7443ee13f7d126f4
5837a6cf7443ee13f7d126f7
5837a6cf7443ee13f7d126f6
5837a6cf7443ee13f7d1270e
5837a6cf7443ee13f7d12706
5837a6cf7443ee13f7d126ea
5837a6cf7443ee13f7d126eb
5837a6cf7443ee13f7d126ec
5837a6cf7443ee13f7d126ed
5837a6cf7443ee13f7d126ee
5837a6cf7443ee13f7d126ef
5837a6cf7443ee13f7d12701
5837a6cf7443ee13f7d1270d
5837a6cf7443ee13f7d12708
5837a6cf7443ee13f7d12709
5837a6cf7443ee13f7d126d3
5837a6cf7443ee13f7d126d2
5837a6cf7443ee13f7d126d1
5837a6cf7443ee13f7d126d0
5837a6cf7443ee13f7d126d7
5837a6cf7443ee13f7d126d6
5837a6cf7443ee13f7d126d5
5837a6cf7443ee13f7d126d4
5837a6cf7443ee13f7d12703
5837a6cf7443ee13f7d1270b
5837a6cf7443ee13f7d126d9
5837a6cf7443ee13f7d126d8
5837a6cf7443ee13f7d12700
5837a6cf7443ee13f7d12707
5837a6cf7443ee13f7d1270a
5837a6cf7443ee13f7d12704
5837a6cf7443ee13f7d12711


In [17]:
# Format branches
def format_branch(node):
    if node.is_root():
        return {}
    else:
        return {
            'length': node.dist,
            'support': node.support,
            'isLeaf': node.is_leaf(),
            'left': str(node.children[0].branch_id) if not node.is_leaf() else None,
            'right': str(node.children[1].branch_id) if not node.is_leaf() else None,
            'entities': [x.entity_id for x in node],
            #'correspondingBranches': node.correspondingBranches
        }

for _, tval in trees.iteritems():
    tval['branches'] = {node.branch_id: format_branch(node) for node in tval['ete_tree'].traverse('levelorder')}
    # print tval['branches']
    # break


In [21]:
# Format trees
def format_tree(t):
    return {
        'type': t['type'],
        'entities': [x.entity_id for x in t['ete_tree']],
        'branches': t['branches'],
        'rootBranch': t['ete_tree'].branch_id,
        'rfDistance': t['rfDistance']      
    }
output_trees = {tid: format_tree(tval) for tid, tval in trees.iteritems()}


In [22]:
# Save JSON file
dataset = {
    'inputGroupId': INPUT_GROUP_ID,
    'title': TITLE,
    'trees': output_trees
}

fp = open(OUTPUT_FILENAME, 'w')
json.dump(dataset, fp)
fp.close()

TypeError: ObjectId('5837a6d07443ee13f7d12715') is not JSON serializable