In [1]:
import random
import os
import sys 
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import numpy as np
from Bio import SeqIO
import gzip
# Seed Python's built-in `random` module so any use of it is reproducible.
# NOTE(review): numpy's global RNG is not seeded in this cell.
random.seed(0)

In [2]:
# Work inside the NK_0145 data directory. Set NK_0145_DATA_DIR to point at a
# different location when running on another machine.
os.chdir(os.environ.get(
    "NK_0145_DATA_DIR",
    "/Users/nk/Documents/backupped/Research/YachieLabLocal/FRACTAL/data/NK_0145",
))
# Output directory for figures. exist_ok tolerates re-runs without the old
# bare `except`, which also hid genuine failures such as permission errors.
os.makedirs("figures", exist_ok=True)

In [3]:
# Reset to matplotlib's default style *before* customising anything.
# style.use('default') overwrites rcParams such as font.family and font.size,
# so running it last discarded every setting below.
plt.style.use('default')

# Figure-wide typography and axis styling.
# NOTE(review): font.sans-serif is only consulted when font.family is
# 'sans-serif'; with font.family='Helvetica' it has no effect.
mpl.rcParams['font.family']       = 'Helvetica'
mpl.rcParams['font.sans-serif']   = ["Helvetica","Arial","DejaVu Sans","Lucida Grande","Verdana"]
mpl.rcParams['figure.figsize']    = [4,3]
mpl.rcParams['font.size']         = 9
mpl.rcParams["axes.labelcolor"]   = "#000000"
mpl.rcParams["axes.linewidth"]    = 1.0
mpl.rcParams["xtick.major.width"] = 1.0
mpl.rcParams["ytick.major.width"] = 1.0

# Qualitative palettes: all 10 tab10 colors and all 12 Set3 colors.
cmap1 = plt.cm.tab10
cmap2 = plt.cm.Set3
colors1 = [cmap1(i) for i in range(10)]
colors2 = [cmap2(i) for i in range(12)]

In [4]:
def generate_cmap(colors):
    color_list = []
    values = range(len(colors))
    vmax   = int(np.max(values))
    for v, c in enumerate(colors):
        color_list.append( (v*1.0/ vmax, c) )
    return LinearSegmentedColormap.from_list('custom_cmap', color_list)


def _read_sparse_rows(sparse_mat_fname):
    """Return the comma-split fields of every non-blank line of the CSV."""
    with open(sparse_mat_fname, 'r') as handle:
        return [line.rstrip("\r\n").split(",") for line in handle if line.strip()]


def _cell_value(fields, mode):
    """Return the heatmap value of one CSV record for ``mode``.

    An empty metric field, or an unrecognised ``mode``, yields NaN.
    """
    if mode == "RunTime":
        return float(fields[7]) if fields[7] else np.nan
    if mode == "Memory":
        # An empty field must give NaN, like the other modes, rather than
        # falling through to float('') and raising ValueError.
        return float(fields[5]) if fields[5] else np.nan
    if mode == "NRFdist":
        # Column 10 is presumably a normalised RF distance; plotted as
        # accuracy in percent.
        return (1 - float(fields[10])) * 100 if fields[10] else np.nan
    if mode == "Coverage":
        return float(fields[9]) / float(fields[4]) if fields[9] else np.nan
    return np.nan


def sparsemat_heatmap(sparse_mat_fname, mode=""):
    """Build a threshold x subsample matrix of one metric from a sparse CSV.

    Each line of ``sparse_mat_fname`` is a comma-separated record. Column 0
    is the subsample size and column 1 the threshold, both integral. The
    metric extracted depends on ``mode``:

    * ``"RunTime"``  -- column 7, then log2 fold-change vs. the mean of row 0
    * ``"Memory"``   -- column 5, then log2 fold-change vs. the mean of row 0
    * ``"NRFdist"``  -- ``(1 - column 10) * 100``
    * ``"Coverage"`` -- ``column 9 / column 4``

    Cells are NaN when there is no record for that (threshold, subsample)
    pair, when the metric field is empty, or when ``mode`` is unrecognised.
    Blank lines are skipped. If a pair appears twice, the last record wins.

    Returns:
        ``(x_list, y_list, matrix)``:
        - ``x_list``: thresholds, descending; these are the matrix rows.
        - ``y_list``: subsample sizes, ascending; these are the matrix columns.
        - ``matrix``: float ndarray of shape ``(len(x_list), len(y_list))``.
    """
    rows = _read_sparse_rows(sparse_mat_fname)

    # Parse the keys once. int(float(...)) accepts both "500" and "500.0".
    keyed = [(int(float(f[1])), int(float(f[0])), f) for f in rows]

    x_list = sorted({x for x, _, _ in keyed}, reverse=True)  # thresholds
    y_list = sorted({y for _, y, _ in keyed})                # subsample sizes
    x_dict = {x: i for i, x in enumerate(x_list)}
    y_dict = {y: j for j, y in enumerate(y_list)}

    matrix = np.full((len(x_list), len(y_list)), np.nan)
    for x, y, fields in keyed:
        matrix[x_dict[x], y_dict[y]] = _cell_value(fields, mode)

    if mode in ("RunTime", "Memory") and matrix.size:
        # nanmean: a single missing cell in the reference row must not turn
        # the offset, and therefore the whole matrix, into NaN.
        offset = np.nanmean(matrix[0])
        matrix = np.log2(matrix / offset)

    return (x_list, y_list, matrix)

# Panel layout: column per tree-building method, row per metric
# (row 0 = coverage, row 1 = accuracy).
title={'nj':"RapidNJ (NJ)",'mp': "RAxML (MP)", 'ml': "FastTree (ML)"}
subtitle=['Coverage',"Accuracy (%)",]
mode=['Coverage','NRFdist']
cmap_list=[mpl.cm.RdPu_r,mpl.cm.plasma]

# NOTE(review): `center` and `vmax` are not used in this cell.
center=[1,1,0,0.5]
max_min=[[0,1],[0,100]]
vmax=[1,None]

methods = ['nj', 'mp', 'ml']

# Hard-coded tick positions/labels. They assume every result CSV covers the
# same threshold x subsample grid as the original NK_0145 run.
newxticklocs=[3.5,7.5,11.5,15.5,19.5]
newxlabels=['200','400','600','800','1000']
newyticklocs=[0.5, 4.5, 8.5, 12.5, 16.5, 20.5, 24.5]
newylabels=['Original', '12000', '10000', '8000', '6000', '4000', '2000']

# Per-method mask of cells lacking an accuracy or coverage value, shared by
# both panel rows. Each CSV is read once here instead of once per panel.
# The NRF matrix is no longer shifted by its row-0 mean before the NaN test:
# that shift changed nothing except to blank the entire panel whenever
# row 0 had a gap.
missing_mask = {}
for method in methods:
    nrf_matrix = sparsemat_heatmap("result."+method+".csv", mode="NRFdist")[2]
    coverage_matrix = sparsemat_heatmap("result."+method+".csv", mode="Coverage")[2]
    missing_mask[method] = np.isnan(nrf_matrix) | np.isnan(coverage_matrix)

fig = plt.figure(figsize=(2,2))
for i in range(2):
    for j, method in enumerate(methods):
        ax1 = fig.add_axes([0.1+j*0.9, 0.1+i*1.1, 0.8, 0.8], label="a")
        if i == 1:
            ax1.set_title(title[method], fontsize=14)

        x_list, y_list, value_matrix = sparsemat_heatmap("result."+method+".csv", mode=mode[i])
        # Show only cells where accuracy and coverage were both measured.
        value_matrix[missing_mask[method]] = np.nan

        x_ticklabels = list(x_list)
        x_ticklabels[0] = "Original"  # row 0 is the largest threshold
        x_y_df = pd.DataFrame(data=value_matrix, index=x_ticklabels, columns=y_list)

        # NOTE(review): axes_style only affects axes created inside the
        # context; ax1 already exists, so this likely has no visible effect.
        with sns.axes_style("dark"):
            ax1.patch.set_facecolor('grey')  # NaN cells show as grey
            sns.heatmap(x_y_df, ax=ax1, cmap=cmap_list[i],
                        vmin=max_min[i][0], vmax=max_min[i][1], cbar=False)
            if i == 0 and j == 1:
                ax1.set_xlabel("Subsample Size (sequences)", fontsize=14)
            if j == 0:
                ax1.set_ylabel('Threshold (sequences)', fontsize=14)

        ax1.set_xticks(np.array(newxticklocs))
        ax1.set_yticks(np.array(newyticklocs))
        ax1.set_xticklabels(newxlabels)
        ax1.set_yticklabels(newylabels)

        # Target this panel explicitly; plt.tick_params acts on whatever
        # axes happen to be current.
        if i != 0:
            ax1.tick_params(labelbottom=False, bottom=False)
        if j != 0:
            ax1.tick_params(labelleft=False, left=False)

        # One shared colorbar per row, to the right of the last panel.
        if j == len(methods) - 1:
            ax3 = fig.add_axes([0.1+j*0.9+0.9, 0.1+i*1.1, 0.05, 0.8])
            norm = mpl.colors.Normalize(vmin=max_min[i][0], vmax=max_min[i][1])
            cb1 = mpl.colorbar.ColorbarBase(ax3, cmap=cmap_list[i], norm=norm, orientation='vertical')
            cb1.set_label(subtitle[i], fontsize=14)

fig.savefig("figures/NK_0145_mutation_heatmap_paper.pdf", bbox_inches="tight")
plt.close(fig)