In [1]:
import pandas
import numpy
import os
import sys
from scipy.spatial import distance
from scipy.cluster import hierarchy
from sklearn.metrics import silhouette_score, homogeneity_score, silhouette_samples, adjusted_mutual_info_score
from matplotlib import pyplot as plt

In [2]:
# Per-domain reference taxonomy settings.
#   'column' -> column of the organisms table used as the reference labelling
#   'key'    -> rank name shown in the panel titles
categories_dict = {
    domain: {'column': tax_column, 'key': rank_name}
    for domain, tax_column, rank_name in (
        ('eukaryotes', 'taxonomy_phylum', 'Phylum'),
        ('bacteria', 'taxonomy_phylum', 'Phylum'),
        ('archaea', 'taxonomy_class', 'Class'),
    )
}

# Per-scenario settings.
#   'dir'   -> directory (relative to the notebook) holding the distance matrix
#   'color' -> line / legend-marker colour
#   'label' -> legend text
scenarios_dict = {
    scenario_name: {'dir': directory, 'color': hex_color, 'label': legend_label}
    for scenario_name, directory, hex_color, legend_label in (
        ('rRNA', 'rib_analysis/', '#56862e', 'Ribosomal Sequences'),
        ('proteostasis', 'proteostasis_functional_analysis/', '#026690', 'Proteostasis Machinery'),
        ('hsp40', 'hsp_analysis/hsp40/', '#a62c25', 'HSP40 Sequences'),
        ('hsp70', 'hsp_analysis/hsp70/', '#aa5325', 'HSP70 Sequences'),
    )
}

In [3]:
subcategories = ['archaea', 'bacteria', 'eukaryotes']
scenarios = ['rRNA', 'proteostasis', 'hsp40', 'hsp70']

# The organism table is the same for every subcategory, so read it only once.
organisms_all_df = pandas.read_csv('./files/organisms_with_tax.tsv', sep='\t')

# results[subcategory][scenario] -> DataFrame with one row per cluster count.
results = {}
for subcategory in subcategories:
    organisms_df = organisms_all_df.loc[organisms_all_df.taxonomy == subcategory, :]

    organisms = organisms_df.organism_name
    column = categories_dict[subcategory]['column']
    organisms_taxonomy = dict(zip(organisms_df.organism_name, organisms_df.loc[:, column]))

    # Reference (true) label of every organism, in the order of `organisms`.
    ref_labels = [organisms_taxonomy[organism] for organism in organisms]

    # Number of distinct reference classes; cluster counts are scanned from
    # 20% above that number down to 2.
    length = len(set(ref_labels))
    # numpy.ceil instead of numpy.math.ceil: the `numpy.math` alias is
    # deprecated and no longer available in recent NumPy releases.
    N = int(numpy.ceil(1.2 * length))
    categories_dict.setdefault(subcategory, {}).update({'classes': length})

    for scenario in scenarios:
        main_dir = './' + scenarios_dict[scenario]['dir']
        df = pandas.read_csv(main_dir + 'final_distance_matrix.csv', index_col=0)
        # Restrict rows and columns to this subcategory, in `organisms` order,
        # so that the matrix rows line up with `ref_labels`.
        df = df.loc[organisms, organisms]
        matrix = df.values

        cluster_counts = range(N, 1, -1)
        tmp_data = []
        if len(cluster_counts) > 0:
            # The linkage does not depend on the requested number of clusters,
            # so it is built once and only the tree cut is repeated per `n`.
            condensed_matrix = distance.squareform(matrix, force='to_vector')
            Z = hierarchy.linkage(condensed_matrix, method='ward')
            for n in cluster_counts:
                labels = list(hierarchy.fcluster(Z, criterion='maxclust', t=n))
                #aim = round(adjusted_mutual_info_score(labels_true=ref_labels, labels_pred=labels), 2)
                hs = round(homogeneity_score(labels_true=ref_labels, labels_pred=labels), 2)
                tmp_data.append([n, hs])

        # NOTE: the 'aim' column holds the homogeneity score computed above;
        # the name is kept because the plotting cell reads `tmp.aim`.
        tmp_df = pandas.DataFrame(tmp_data, columns=['n_clusters', 'aim'])
        results.setdefault(subcategory, {}).update({scenario: tmp_df})

In [4]:
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerLine2D
import matplotlib
# Global matplotlib settings: sans-serif font list, mathtext font set and
# the padding between an axes and its title.
matplotlib.rcParams.update({
    'font.sans-serif': ['FreeSans', ],
    'mathtext.fontset': 'custom',
    'axes.titlepad': 4,
})





class SymHandler(HandlerLine2D):
    """Legend handler for ``Line2D`` handles with an adjustable vertical offset.

    The ``ydescent`` value received by ``create_artists`` is discarded and
    replaced by ``y_align * height`` before delegating to ``HandlerLine2D``.

    Attributes
    ----------
    y_align : float
        Fraction of the handle height forwarded as ``ydescent``. Defaults to
        0.3 at class level so an instance is usable without assigning the
        attribute afterwards; it can still be overridden per instance.
    """

    # Class-level default: previously the attribute only existed if callers
    # assigned it after construction, otherwise create_artists raised
    # AttributeError.
    y_align = 0.3

    def create_artists(self, legend, orig_handle, xdescent, ydescent, width,
                       height, fontsize, trans):
        """Build the legend artists, substituting ``y_align * height`` for ``ydescent``."""
        y_offset = self.y_align * height
        return super(SymHandler, self).create_artists(legend, orig_handle, xdescent,
                     y_offset, width, height, fontsize, trans)


# Shared handler instance used by the legend below (class default y_align = 0.3).
SH = SymHandler()

In [6]:
plt.close()

fig = plt.figure(1, figsize=(4.5,8), frameon=False)
# 11 grid rows: each panel spans `step` rows and consecutive panels are
# separated by one unused spacer row (rows 0-2, 4-6, 8-10).
grids = GridSpec(11, 1, hspace=0.25, wspace=0)
step = 3

for i, subcategory in enumerate(subcategories):
    ax = fig.add_subplot(grids[i*step+i:(i+1)*step+i,0])
    for scenario in scenarios:
        tmp = results[subcategory][scenario]
        # The 'aim' column holds the homogeneity score (see the cell that builds `results`).
        ax.plot(tmp.n_clusters, tmp.aim, label=scenario, linewidth=2, color=scenarios_dict[scenario]['color'])
    # Vertical marker at the number of distinct reference classes.
    ax.axvline(x=categories_dict[subcategory]['classes'], linewidth=1, color='black')
    ax.grid(True, linestyle='--', alpha=0.5, linewidth=0.7)
    ax.set_ylim([0,1])
    ax.tick_params(labelsize=8)
    ax.set_title(subcategory.capitalize() + ' ('+categories_dict[subcategory]['key']+'-level)', fontsize=12)
    ax.set_xlabel('Number of Clusters', fontsize=10)
    ax.set_ylabel('Homogeneity Score', fontsize=10)


# One legend handle per scenario, in plotting order (replaces four
# copy-pasted stanzas that differed only in the scenario key).
patches = []
patches_labels = []
for scenario in scenarios:
    scenario_style = scenarios_dict[scenario]
    patches.append(Line2D([0], [0], marker='o', color=scenario_style['color'],
                          markerfacecolor=scenario_style['color'],
                          markersize=15, label=scenario_style['label']))
    patches_labels.append(scenario_style['label'])

# plt.legend attaches to the current axes, i.e. the last panel created above;
# bbox_to_anchor places the legend below that panel.
plt.legend(handles=patches,
           labels=patches_labels,
           bbox_to_anchor=(0.38, -0.80, 0.6, 0.5),
           fontsize=10, ncol=2, fancybox=False, framealpha=0,
           handlelength=1, handletextpad=1,
           handleheight=1.5,
           labelspacing=1.5,
           handler_map={matplotlib.lines.Line2D: SH})

main_dir = './phyla_figure/'
if not os.path.exists(main_dir):
    os.makedirs(main_dir)
# Alternative output formats, kept for reference:
#plt.savefig(main_dir+'hs_plots.tiff', dpi=800, format='tiff', bbox_inches='tight', pad_inches=0.05)
#plt.savefig(main_dir+'hs_plots.png', dpi=800, format='png', bbox_inches='tight', pad_inches=0.05)
plt.savefig(main_dir+'hs_plots.pdf', dpi=600, format='pdf', bbox_inches='tight', pad_inches=0.05)