In [None]:
# Imports and global configuration for the DESeq2 / piano workflow (R is driven through rpy2)
import pandas as pd
import rpy2,os,re
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib_venn as venn
import numpy as np
import seaborn as sns
import matplotlib
# Embed fonts as TrueType (Type 42) so text in saved PDF/PS figures stays editable
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
from sklearn.decomposition import PCA as sklearnPCA
import seaborn as sns  # NOTE(review): duplicate of the seaborn import above
# Let the Intel OpenMP runtime tolerate being loaded more than once; presumably set to
# avoid the duplicate-libiomp abort some conda setups hit -- a workaround, not a fix
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import pickle
#rpy2
import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
# Automatic numpy -> R conversion for values assigned into robjects.globalenv
rpy2.robjects.numpy2ri.activate()
from rpy2.robjects import pandas2ri
# Automatic pandas <-> R data.frame conversion
pandas2ri.activate()
import warnings
from rpy2.rinterface import RRuntimeWarning
# Silence warnings emitted from R code (note: this also hides DESeq2/piano warnings)
warnings.filterwarnings("ignore", category=RRuntimeWarning)


In [None]:
class DESeq_Piano:
    """DESeq2 differential expression + piano gene-set analysis driven from Python via rpy2.

    Typical use::

        k = DESeq_Piano('results/')
        k.DEseq(count_df, conds_series)          # fit DESeq2 (design ~conds)
        k.loadGSC(KEGG='kegg.gmt', GO='go.gmt')  # load gene-set collections
        k.DEseq_Compare('A', 'B')                # contrast A vs B
        k.PIANO('A', 'B')                        # gene-set analysis (+ heatmaps)
        k.summarizeDEseq(); k.summarizePiano()   # Excel summaries

    Files written below ``respath``:

    * ``deseq/deseq_<cond1>_<cond2>.txt``       DESeq2 result tables
    * ``<GSCtype>/piano_<cond1>_<cond2>.txt``   piano summary tables
    * ``Figures/``                              PCA plot and per-collection heatmaps
    * ``DifferentialExpression_<part>.xlsx`` and ``<GSCtype>_<part>.xlsx``
    """

    # Gene-set collection types, in the order PIANO() processes them.
    _GSC_TYPES = ('KEGG', 'GO', 'TF', 'RM', 'extraGSC')
    # Collections matched against upper-cased gene names (presumably because their
    # .gmt files use upper-case symbols -- TODO confirm against the gene-set files).
    _UPPERCASE_GSC = ('KEGG', 'GO')
    # piano summary-table columns holding the directional adjusted p-values.
    _PADJ_UP = 'p adj (dist.dir.up)'
    _PADJ_DN = 'p adj (dist.dir.dn)'

    def __init__(self,respath='.'):
        """Create the analysis object and load the required R libraries.

        Parameters
        ----------
        respath : str
            Directory under which every result is written. It is created
            (including missing parents) when it does not exist.
        """
        if respath=='.':
            print('Result Path is Not Defined, any result will be written in this directory')
        elif os.path.isdir(respath):
            print('Result Path Exists, please be careful of overwriting')
        else:
            print('Result Path Does Not Exists, creating new directory')
            # makedirs (not mkdir) so nested, not-yet-existing paths work as well
            os.makedirs(respath)
        self.respath=respath
        self.__Init()

    def __Init(self):
        """Load the R packages and reset every result/state attribute to ``None``."""
        robjects.r('''
            library('DESeq2')
            library('piano')
            library('Biobase')
            library('snow')
            library('RColorBrewer')
            library('gplots')
            library('snowfall')
        ''')
        # Gene-set collections returned by piano::loadGSC (R objects)
        self.GSC_KEGG=None
        self.GSC_TF=None
        self.GSC_RM=None
        self.GSC_GO=None
        self.GSC_extraGSC=None
        # piano summary tables (pandas DataFrames), one per collection
        self.KEGG=None
        self.TF=None
        self.RM=None
        self.GO=None
        self.extraGSC=None
        # Fitted DESeqDataSet (R object)
        self.dds=None
        # User-assignable DESeq result table; PIANO() falls back to it when deseq_df is unset
        self.deseq_res=None
        # DESeq result table of the latest DEseq_Compare() call
        self.deseq_df=None
        self.gene_names=None
        self.conds=None
        self.count=None
        self.tpm=None

    def __check_dir(self,extrapath):
        """Ensure ``<respath>/<extrapath>/`` exists, creating missing parent folders too."""
        path='%s/%s/' % (self.respath,extrapath)
        if os.path.isdir(path):
            print('Path %s exist, writing in the folder' % path)
        else:
            print('Path %s does not exist, making the folder to write the results' % path)
            os.makedirs(path,exist_ok=True)

    def DEseq(self,count_df=None,conds_series=None):
        """Fit a DESeq2 model with design ``~conds``; the result is stored in ``self.dds``.

        Parameters
        ----------
        count_df : pandas.DataFrame, optional
            Genes x samples raw counts (cast to int before being sent to R).
            Defaults to the previously stored ``self.count``.
        conds_series : pandas.Series, optional
            Condition label per sample, indexed by sample name. Defaults to
            ``self.conds``. NOTE(review): the count matrix is sent to R without
            column names, so its column order must match this series -- TODO confirm
            at the call site.

        Raises
        ------
        ValueError
            If no counts or no conditions are available.
        """
        if count_df is not None:
            self.count=count_df
        if conds_series is not None:
            self.conds=conds_series
        if self.count is None or self.conds is None:
            raise ValueError('No count data or conditions found, pass count_df and conds_series')
        # Take gene names from the stored counts so a re-run without arguments still works
        self.gene_names=self.count.index
        robjects.globalenv['conds']=np.array(self.conds)
        robjects.globalenv['subject']=np.array(self.conds.index)
        robjects.globalenv['count']=self.count.astype(int).values
        robjects.r('''
            library('DESeq2')
            conds=as.factor(conds)
            coldata <- data.frame(row.names=subject,conds)
            dds <- DESeqDataSetFromMatrix(countData=as.matrix(count),colData=coldata,design=~conds)
            dds <- DESeq(dds)
            print(resultsNames(dds))
            baseMeanPerLvl <- sapply( levels(conds), function(lvl) rowMeans( counts(dds,normalized=TRUE)[,conds == lvl] ) )
            conds2=colnames(baseMeanPerLvl)
            gene_names=rownames(baseMeanPerLvl)

        ''')
        self.dds=robjects.globalenv['dds']

    def DEseq_Compare(self, cond1, cond2,save=True):
        """Extract the DESeq2 results for the contrast ``cond1`` vs ``cond2``.

        ``log2FoldChange`` is log2(cond1 / cond2) (DESeq2 ``contrast`` convention).
        The table, indexed by gene name, is stored in ``self.deseq_df``. When
        ``save`` is true it is also written to
        ``<respath>/deseq/deseq_<cond1>_<cond2>.txt``.

        Raises
        ------
        ValueError
            If ``DEseq`` has not been run yet.
        """
        if self.dds is None:
            raise ValueError('No DESeq data found, run the DEseq function')
        print('Comparing %s and %s' % (cond1,cond2))
        robjects.globalenv['dds']=self.dds
        robjects.globalenv['cond1']=cond1
        robjects.globalenv['cond2']=cond2
        robjects.globalenv['gene_names']=np.array(self.gene_names)
        robjects.r('''
            res=results(dds,contrast=c('conds',cond1,cond2))
            res=data.frame(res)
            res$rownames=gene_names
        ''')
        res=pandas2ri.rpy2py_dataframe(robjects.globalenv['res']).set_index('rownames')
        self.deseq_df=res
        if save:
            self.__check_dir('deseq')
            res.to_csv('%s/deseq/deseq_%s_%s.txt' % (self.respath,cond1,cond2),sep='\t')

    def loadGSC(self,KEGG='',TF='',RM='',GO='',extraGSC=''):
        """Load gene-set collection files with piano ``loadGSC``.

        Each non-empty argument is a path passed to ``loadGSC``. The result is stored
        as ``self.GSC_<type>``, and the matching ``<respath>/<type>/`` output folder
        is created.
        """
        sources={'KEGG':KEGG,'TF':TF,'RM':RM,'GO':GO,'extraGSC':extraGSC}
        for name,path in sources.items():
            if path:
                self.__check_dir(name)
                setattr(self,'GSC_%s' % name,self.__loadGSC_execute(path))

    def __loadGSC_execute(self,GSC):
        """Run piano ``loadGSC`` on one file and return the resulting R object."""
        robjects.globalenv['GSC']=GSC
        robjects.r('''
            y=loadGSC(GSC)
        ''')
        return robjects.globalenv['y']

    def __PIANO_execute(self,cond1,cond2,GSCtype='',save=True):
        """Run piano ``runGSA`` on ``self.deseq_df`` against one gene-set collection.

        The run uses ``geneSetStat="reporter"`` and ``signifMethod="nullDist"`` with
        1000 permutations on 8 CPUs. Missing log2 fold changes become 0 and missing
        p-values become 1.

        Returns the ``GSAsummaryTable`` as a DataFrame indexed by gene-set ``Name``.
        With ``save`` it is also written to
        ``<respath>/<GSCtype>/piano_<cond1>_<cond2>.txt``.
        """
        gsc={'KEGG':self.GSC_KEGG,'TF':self.GSC_TF,'RM':self.GSC_RM,'GO':self.GSC_GO}
        robjects.globalenv['y']=gsc.get(GSCtype,self.GSC_extraGSC)
        deseq_df=self.deseq_df
        if GSCtype in self._UPPERCASE_GSC:
            # Work on a copy: renaming self.deseq_df in place would leak the
            # upper-cased names into the later TF/RM/extraGSC runs.
            deseq_df=deseq_df.copy()
            deseq_df.index=[i.upper() for i in deseq_df.index]
        robjects.globalenv['deseq_file']=pandas2ri.py2rpy_pandasdataframe(deseq_df)
        robjects.r('''
            DESeq_file=deseq_file
            DESeq_file=DESeq_file[ ,c('log2FoldChange','pvalue')]
            logFC=as.matrix(DESeq_file[,1])
            pval=as.matrix(DESeq_file[,2])
            rownames(logFC)=(rownames(DESeq_file))
            rownames(pval)=(rownames(DESeq_file))
            logFC[is.na(logFC)] <- 0
            pval[is.na(pval)] <- 1
            gsaRes <- runGSA(pval,logFC,gsc=y, geneSetStat="reporter", signifMethod="nullDist", nPerm=1000,ncpus=8)
            res_piano=GSAsummaryTable(gsaRes)
            res_piano$rownames=rownames(res_piano)
        ''')
        res=pandas2ri.rpy2py_dataframe(robjects.globalenv['res_piano']).drop(columns='rownames').set_index('Name')
        if save:
            res.to_csv('%s/%s/piano_%s_%s.txt' % (self.respath,GSCtype,cond1,cond2),sep='\t')
        return res

    def PIANO_heatmap(self,cond1,cond2,GSCtype,save=True,thr=0.05):
        """Clustered heatmap of the gene sets significant in one piano result.

        Each gene set gets one signed score: ``-log10(p adj up)`` when the up
        direction is the stronger one, ``log10(p adj dn)`` otherwise. Up-regulated
        sets are therefore positive and down-regulated sets negative. When more
        than 100 sets pass ``thr``:

        * if both directions are present, the 50 strongest per direction are shown;
        * otherwise, the 100 strongest are shown.

        Parameters
        ----------
        cond1, cond2 : str
            Conditions of the comparison; used for the column label and file name.
        GSCtype : str
            'KEGG', 'TF', 'RM' or 'GO'; any other value selects ``extraGSC``.
        save : bool
            Write ``<respath>/Figures/<GSCtype>/Heatmap_<cond1>_<cond2>.pdf``.
        thr : float
            Adjusted p-value cutoff for a gene set to be shown.

        Returns
        -------
        seaborn.matrix.ClusterGrid

        Raises
        ------
        ValueError
            If the piano result is missing or fewer than two gene sets pass ``thr``.
        """
        results={'KEGG':self.KEGG,'TF':self.TF,'RM':self.RM,'GO':self.GO}
        piano_res=results.get(GSCtype,self.extraGSC)
        if piano_res is None:
            raise ValueError('No PIANO result for %s, run PIANO first' % GSCtype)
        up_col,dn_col=self._PADJ_UP,self._PADJ_DN
        # De-duplicate on the gene-set name. Calling drop_duplicates() on the values
        # would also drop distinct gene sets that merely share adjusted p-values.
        piano_res=piano_res.loc[~piano_res.index.duplicated(),[up_col,dn_col]]
        up=-np.log10(piano_res.loc[piano_res[up_col]<thr,[up_col]])
        dn=np.log10(piano_res.loc[piano_res[dn_col]<thr,[dn_col]])

        def shorten(name):
            # Very long (merged) gene-set names are cut at the first '&'
            return name.split('&')[0] if len(name)>100 else name

        up.index=[shorten(i) for i in up.index]
        dn.index=[shorten(i) for i in dn.index]
        # One row per gene set holding both directional scores (NaN where not significant)
        both=pd.concat([up,dn]).groupby(level=0).max()
        up_score=both[up_col].fillna(0)
        dn_score=both[dn_col].fillna(0)
        label='%s vs %s' % (cond1,cond2)
        # Keep the direction with the larger magnitude (ties go to "up")
        temp=up_score.where(up_score>=dn_score.abs(),dn_score).to_frame(label)
        # p adj == 0 gives +/-inf; cap at 1.1x the largest finite magnitude
        finite=temp.where(np.isfinite(temp)).abs().max().max()
        cap=1.1*np.nanmax([finite,-np.log10(thr)])
        temp=temp.replace([np.inf,-np.inf],[cap,-cap])
        temp=temp[(temp.abs()>-np.log10(thr)).sum(1)>0].fillna(0)
        truncated=temp.shape[0]>100
        if truncated:
            temp=temp.reindex(temp[label].abs().sort_values(ascending=False).index)
            temp_d=temp[temp[label]<0]
            temp_u=temp[temp[label]>0]
            if temp_d.empty:
                temp_u=temp_u.iloc[:100]
            elif temp_u.empty:
                temp_d=temp_d.iloc[:100]
            else:
                temp_d=temp_d.iloc[:50]
                temp_u=temp_u.iloc[:50]
            temp=pd.concat([temp_d,temp_u])
        if temp.shape[0]<2:
            raise ValueError('Fewer than two gene sets pass p adj < %s for %s; nothing to cluster' % (thr,GSCtype))
        cmap=matplotlib.colors.LinearSegmentedColormap.from_list("", ["#0000a5",'#0000d8',"#FFFAF0",'#d80000',"#a50000"])
        g=sns.clustermap(temp,figsize=(2.5, 10),col_cluster=False,row_cluster=True,linewidths=0.1,cmap=cmap,center=0.0,cbar_kws={'label': '-Log10(P-Adjusted)'})
        # Label every row, in dendrogram order
        g.ax_heatmap.set(yticks=[i+0.5 for i in range(temp.shape[0])])
        g.ax_heatmap.set_yticklabels(labels=temp.index[g.dendrogram_row.reordered_ind],size=6 if truncated else 8)
        g.ax_heatmap.xaxis.tick_top()
        if save:
            path=self.respath+'/Figures/'+GSCtype
            os.makedirs(path,exist_ok=True)
            g.savefig('%s/Heatmap_%s_%s.pdf' % (path,cond1,cond2))
        return g

    def PIANO(self,cond1,cond2,heatmap=True,save=True):
        """Run piano for every loaded gene-set collection (optionally with heatmaps).

        Uses ``self.deseq_df``, or ``self.deseq_res`` when no comparison has been
        run. Results are stored in ``self.KEGG``/``GO``/``TF``/``RM``/``extraGSC``.

        Raises
        ------
        ValueError
            If no DESeq result table is available.
        """
        if save:
            self.__check_dir('Figures')
        if getattr(self,'deseq_df',None) is None:
            self.deseq_df=getattr(self,'deseq_res',None)
        if self.deseq_df is None:
            raise ValueError('No DEseq comparison found, run DEseq_Compare first or assign deseq result dataframe to deseq_res')
        for name in self._GSC_TYPES:
            if getattr(self,'GSC_%s' % name,None) is None:
                continue
            setattr(self,name,self.__PIANO_execute(cond1,cond2,GSCtype=name,save=save))
            if heatmap:
                if save:
                    self.__check_dir('Figures/%s' % name)
                try:
                    self.PIANO_heatmap(cond1,cond2,GSCtype=name,save=save)
                except ValueError as err:
                    # Best effort: a collection without enough hits just gets no heatmap
                    print('Heatmap for %s skipped: %s' % (name,err))

    @staticmethod
    def _sheet_name(fname,prefix):
        """Excel sheet name from a result file name: strip prefix/'.txt', cap at 31 chars."""
        name=fname.replace(prefix,'').replace('.txt','')
        if len(name)>31:
            print('sheet name too big, trimming to 31 characters')
            name=name[:31]
        return name

    def __piano2xlsx(self,name,part='ALL'):
        """Write significant gene sets of every piano table in ``<respath>/<name>/`` to Excel.

        Output is ``<respath>/<name>_<part>.xlsx``, with one sheet per table. Each sheet
        lists the sets with p adj < 0.05 in either direction, tagged UP/DOWN and sorted
        by P-Adj.
        """
        folder='%s/%s' % (self.respath,name)
        files=[f for f in sorted(os.listdir(folder)) if not f.startswith('.') and (part=='ALL' or part in f)]
        # Context manager closes the file; ExcelWriter.save() was removed in pandas 2.0
        with pd.ExcelWriter('%s/%s_%s.xlsx' % (self.respath,name,part), engine='xlsxwriter') as writer:
            for fname in files:
                table=pd.read_csv('%s/%s' % (folder,fname),index_col='Name',sep='\t')
                frames=[]
                for col,direction in ((self._PADJ_UP,'UP'),(self._PADJ_DN,'DOWN')):
                    sub=table[['Genes (tot)',col]].set_axis(['# of Genes','P-Adj'],axis=1)
                    frames.append(sub[sub['P-Adj']<0.05].assign(Direction=direction))
                pd.concat(frames).sort_values('P-Adj').to_excel(writer, sheet_name=self._sheet_name(fname,'piano_'))

    def summarizePiano(self,names=('KEGG','GO','TF','RM','extraGSC'),part='ALL'):
        """Create one Excel summary per existing gene-set result folder in ``names``."""
        for name in names:
            # os.path.join: plain concatenation missed folders when respath lacked a trailing '/'
            if os.path.isdir(os.path.join(self.respath,name)):
                self.__piano2xlsx(name,part=part)

    def summarizeDEseq(self,deseq_result='deseq',part='ALL'):
        """Collect DESeq result tables into ``<respath>/DifferentialExpression_<part>.xlsx``.

        Writes one sheet per file in ``<respath>/<deseq_result>/``. Unless ``part`` is
        'ALL', only files whose names contain ``part`` are included. Each sheet has the
        columns L2FC, P-VALUE, P-ADJ and DIRECTION (UP when L2FC > 0, otherwise DOWN)
        and is sorted by P-VALUE.

        Raises
        ------
        ValueError
            If the result folder does not exist.
        """
        folder=os.path.join(self.respath,deseq_result)
        if not os.path.isdir(folder):
            raise ValueError('No DESEq result found')
        files=[f for f in sorted(os.listdir(folder)) if not f.startswith('.') and (part=='ALL' or part in f)]
        # Context manager closes the file; ExcelWriter.save() was removed in pandas 2.0
        with pd.ExcelWriter(os.path.join(self.respath,'DifferentialExpression_%s.xlsx' % part), engine='xlsxwriter') as writer:
            for fname in files:
                deseq=pd.read_csv(os.path.join(folder,fname),index_col=0,sep='\t')
                table=(deseq.assign(abs=deseq['log2FoldChange'].abs())
                            .sort_values('abs',ascending=False)[['log2FoldChange','pvalue','padj']])
                table=table.assign(Direction=np.where(table['log2FoldChange']>0,'UP','DOWN'))
                table.columns=['L2FC','P-VALUE','P-ADJ','DIRECTION']
                table=table.sort_values('P-VALUE')
                table.to_excel(writer, sheet_name=self._sheet_name(fname,'deseq_'))

    def save_object(self,filename='DEseq_PIANO.pkl'):
        """Pickle this object to ``<respath>/<filename>``."""
        with open('%s/%s' % (self.respath,filename), 'wb') as file:
            pickle.dump(self, file)

    def load_object(self,filename='DEseq_PIANO.pkl'):
        """Unpickle and return an object saved with ``save_object``.

        ``filename`` is tried as given first, then relative to ``respath`` (where
        ``save_object`` writes). Security: unpickling can execute arbitrary code, so
        only load trusted files.
        """
        if not os.path.isfile(filename):
            candidate=os.path.join(self.respath,filename)
            if os.path.isfile(candidate):
                filename=candidate
        with open(filename, 'rb') as file:
            return pickle.load(file)

    def PCA(self,tpm='',conds='',annot='',respath='',save=False,part='',not_part=None):
        """2-component PCA of log10(TPM + 1), with samples coloured/marked by condition.

        Parameters
        ----------
        tpm : pandas.DataFrame, optional
            Genes x samples TPM table; defaults to ``self.tpm``.
        conds : pandas.Series, optional
            Condition per sample (index = sample names); defaults to ``self.conds``.
        annot : sequence of str, optional
            Text placed next to each point. It follows the sample order of
            ``conds`` after filtering.
        respath : str
            Unused; kept for backward compatibility.
        save : bool
            Write ``<respath>/Figures/PCA_<part>.pdf``.
        part : str
            Keep only samples whose condition contains this (regex) pattern.
        not_part : str, optional
            Drop samples whose condition contains this pattern. Only applied when
            ``part`` is ''.

        Returns
        -------
        pandas.DataFrame
            Loadings ``x``/``y`` per gene: components scaled by sqrt(explained variance).
        """
        if not isinstance(conds,pd.Series):
            conds=self.conds
        if not isinstance(tpm,pd.DataFrame):
            tpm=self.tpm
        # Keep genes with mean TPM > 1 in at least one condition (evaluated on all samples)
        expressed=sum(tpm[conds[conds==c].index].mean(1)>1 for c in set(conds))
        tpm=tpm[expressed>0]
        if part=='' and not_part is None:
            part='ALL'
        elif part!='':
            conds=conds[conds.str.contains(part)]
        else:
            part='NOT %s' % not_part
            # Filter on not_part itself, not on the 'NOT ...' label built for the file name
            conds=conds[~conds.str.contains(not_part)]
        # PCA rows are matched to conditions by position, so follow the conds order
        tpm=tpm[conds.index]
        sklearn_pca=sklearnPCA(n_components=2)
        Y_sklearn=sklearn_pca.fit_transform(np.log10(tpm.T+1))
        # 'seaborn-white' was renamed 'seaborn-v0_8-white' in matplotlib 3.6 and later removed
        style='seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white'
        markers=['o','*','v','^']
        colours=['k','lightblue','yellow','green']
        labels=np.asarray(conds)
        with plt.style.context(style):
            fig=plt.figure(figsize=(15,8))
            ax=fig.add_subplot(111, aspect='auto')
            ax.grid(False)
            # Marker changes every group and colour every 4 groups; wraps after 16 groups
            for j,lab in enumerate(sorted(set(conds))):
                mask=labels==lab
                ax.scatter(Y_sklearn[mask,0],Y_sklearn[mask,1],c=colours[(j//4)%4],label=lab,marker=markers[j%4],s=200)
            if annot is not None and len(annot)>0:
                for i,txt in enumerate(annot):
                    ax.annotate(txt,(Y_sklearn[i,0],Y_sklearn[i,1]))
            ax.set_xlabel('PC1 (%.2f)' % (sklearn_pca.explained_variance_ratio_[0]*100))
            ax.set_ylabel('PC2 (%.2f)' % (sklearn_pca.explained_variance_ratio_[1]*100))
            ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
            fig.tight_layout()
        if save:
            self.__check_dir('Figures')
            fig.savefig('%s/Figures/PCA_%s.pdf' % (self.respath,part))
        loadings=sklearn_pca.components_.T * np.sqrt(sklearn_pca.explained_variance_)
        return pd.DataFrame(loadings,columns=['x','y'],index=tpm.index)

In [None]:
# Raw counts, TPM and sample metadata: tab-separated, first column is the row index
count=pd.read_csv('../Data/count_geneName.txt',sep='\t',index_col=0)
tpm=pd.read_csv('../Data/tpm_geneName.txt',sep='\t',index_col=0)
metadata=pd.read_csv('../Data/metadata.txt',sep='\t',index_col=0)

# Condition label per sample, in metadata order
conds=metadata['Conds']
annot=conds.index
# Align both expression matrices to the metadata sample order
tpm=tpm.loc[:,conds.index]
count=count.loc[:,conds.index]

In [None]:
RESULTS_DIR='../Results_DEG/'
GSC_DIR='../../MainLibrary/Current/'

k=DESeq_Piano(RESULTS_DIR)
k.DEseq(count_df=count,conds_series=conds)
# NOTE(review): the KEGG sets are the 2019 mouse release, the GO sets the 2018 human release -- confirm this mix is intended
k.loadGSC(KEGG=GSC_DIR+'KEGG_2019_Mouse.gmt',GO=GSC_DIR+'GO_Biological_Process_2018_human.gmt')
loadings=k.PCA(tpm=tpm,conds=conds,save=True,annot=annot.tolist())

In [None]:
# Contrast to test. DESeq2 reports log2FoldChange as log2(cond1 / cond2).
# Both values must be condition labels present in `conds` -- fill them in before running.
cond1=''
cond2=''
unknown=[c for c in (cond1,cond2) if c not in set(conds)]
if unknown:
    # Fail early with a readable message instead of a cryptic R contrast error
    raise ValueError('Unknown condition(s) %s; choose from %s' % (unknown,sorted(set(conds))))
k.DEseq_Compare(cond1,cond2,save=True)
k.PIANO(cond1,cond2,heatmap=False,save=True)

In [None]:
# Collect every per-comparison table into Excel workbooks under the results folder
for summarize in (k.summarizeDEseq, k.summarizePiano):
    summarize()