In [28]:
import os
import shutil
import pandas as pd

from ete3 import Tree
from copy import deepcopy
from subprocess import Popen, PIPE
from utils.revbayes import revbayes_template, treealign, read_nexus_sum_into_ete3

In [29]:
# Reference (ground-truth) phylogeny, parsed by ete3 from the file 'phylogeny.nh'.
# Copies of this tree are pruned and compared against inferred trees further down.
gt = Tree(newick='phylogeny.nh')

# Optional inspection -- enable one of the lines below to look at the tree:
# print(gt)
# gt.show()

In [30]:
# ---------------------------------------------------------------------------
# Evaluate one training run:
#   1. collect the best logged NMI / recall value for each split,
#   2. run RevBayes on the Val and Test trait files to obtain a tree per split,
#   3. score each of those trees against the ground-truth phylogeny `gt`.
# The result is the labelled Series `ser`, displayed at the end of the cell.
# ---------------------------------------------------------------------------

# Directory of the training run; its CSV logs and trait (.nex) files are read from here.
run = '/mnt/cluster/Training_Results/rove/ROVE_RESNET50_FROZEN_NORMALIZE_multisim4_2022-9-8-9-46-4'

# Working directory for RevBayes. It must not contain certain characters (and
# possibly must not be too long), otherwise RevBayes will not run on it.
out_dir = '/home/rob/revbayes_run'

# True: stream RevBayes' stdout into the notebook while it runs; False: run quietly.
print_output = True

# A split whose NMI log has fewer rows than this is left out of the metric summary.
# NOTE(review): presumably rows are logged evaluations and short logs mean an
# unfinished run -- confirm against the training code.
min_log_rows = 50

# Fixed layout of the summary. Anything that cannot be computed stays NaN.
summary_index = ['run',
                 'Train_nmi', 'Train_recall@1',
                 'Val_nmi', 'Val_recall@1',
                 'Test_nmi', 'Test_recall@1',
                 'Val_Align', 'Val_Max',
                 'Test_Align', 'Test_Max']

# Results are keyed by label (not appended positionally) so that a skipped split
# can neither shift later values under the wrong label nor break the Series
# construction below.
summary = {'run': run}

nmi_fp = os.path.join(run, 'CSV_Logs', 'Data_Test_discriminative_nmi.csv')
traits_fp = os.path.join(run, 'traits_discriminative_Test.nex')

# Only runs that have both the Test NMI log and the Test trait file are evaluated.
if os.path.exists(nmi_fp) and os.path.exists(traits_fp):
    os.makedirs(out_dir, exist_ok=True)

    # --- 1. best logged metrics per split ---------------------------------
    for _set in ['Train', 'Val', 'Test']:
        nmi_fp = os.path.join(run, 'CSV_Logs', f'Data_{_set}_discriminative_nmi.csv')
        recall_fp = os.path.join(run, 'CSV_Logs', f'Data_{_set}_discriminative_e_recall.csv')
        if not (os.path.exists(nmi_fp) and os.path.exists(recall_fp)):
            continue

        nmi = pd.read_csv(nmi_fp)
        recall = pd.read_csv(recall_fp)
        if len(nmi) < min_log_rows:
            continue

        # Column-wise maximum, first column, scaled by 100.
        summary[f'{_set}_nmi'] = nmi.max().iloc[0] * 100
        summary[f'{_set}_recall@1'] = recall.max().iloc[0] * 100

    # --- 2 + 3. tree inference and comparison with the ground truth -------
    for _set in ['Val', 'Test']:
        out_tree_dir = os.path.join(out_dir, _set)
        os.makedirs(out_tree_dir, exist_ok=True)

        # An existing OUTSUMFILE acts as a cache: RevBayes is only run when it is absent.
        out_tree_fp = os.path.join(out_tree_dir, 'OUTSUMFILE')
        if not os.path.exists(out_tree_fp):

            # Copy the trait file next to the outputs, i.e. somewhere RevBayes can run on.
            traits_fp = os.path.join(run, f'traits_discriminative_{_set}.nex')
            new_traits_fp = os.path.join(out_tree_dir, f'traits_discriminative_{_set}.nex')
            shutil.copy(traits_fp, new_traits_fp)

            # Write the .rev script, filled in with this split's directory and name.
            bayes_str = revbayes_template.format(_dir=out_tree_dir, _set=_set)
            bayes_fp = os.path.join(out_dir, 'run_traits.rev')
            with open(bayes_fp, 'w') as f:
                f.write(bayes_str)
            print(bayes_fp)

            print('About to start')
            # Run the script with the RevBayes app of the Singularity image.
            cmd = ['singularity', 'run', '--app', 'rb', 'RevBayes_Singularity_1.1.1.simg', bayes_fp]
            rb_stderr = ''
            if print_output:
                with Popen(cmd, stdout=PIPE, bufsize=1, universal_newlines=True) as p:
                    for line in p.stdout:
                        print(line, end='')
            else:
                # Only Popen/PIPE are imported from subprocess, so the quiet path
                # uses Popen too instead of the (unavailable) subprocess.run.
                with Popen(cmd, stdout=PIPE, stderr=PIPE, universal_newlines=True) as p:
                    rb_stdout, rb_stderr = p.communicate()

            # Leaving the `with` block waits for the process, so returncode is set here.
            if p.returncode != 0:
                print(f'Warning: RevBayes exited with code {p.returncode} for the {_set} set')
                if rb_stderr:
                    print(rb_stderr)

        # Load the tree from OUTSUMFILE as an ete3 tree.
        tree = read_nexus_sum_into_ete3(out_tree_fp)

        # Restrict a copy of the ground truth to the leaves present in the inferred
        # tree (the copy keeps `gt` itself intact for the other split), then score.
        pruned_gt = deepcopy(gt)
        pruned_gt.prune([leaf.name for leaf in tree.get_leaves()])
        align, align_max = treealign(tree, pruned_gt)

        # Keep the pruned trees under their original names for any later cell.
        if _set == 'Val':
            val_gt = pruned_gt
        else:
            test_gt = pruned_gt

        summary[f'{_set}_Align'] = align
        summary[f'{_set}_Max'] = align_max
else:
    print(f'Skipping evaluation, required input missing: {nmi_fp} and/or {traits_fp}')

# Labels without a computed value are filled with NaN by the constructor.
ser = pd.Series(summary, index=summary_index)
ser

run               /mnt/cluster/Training_Results/rove/ROVE_RESNET...
Train_nmi                                                 91.470793
Train_recall@1                                            96.355753
Val_nmi                                                   70.826816
Val_recall@1                                               87.70175
Test_nmi                                                  66.911117
Test_recall@1                                             86.163522
Val_Align                                                   7.89753
Val_Max                                                        21.5
Test_Align                                                  3.52493
Test_Max                                                       10.5
dtype: object