# Draw BarChart

In [None]:
# output_Stage
import pygraphviz as pgv
import math
from IPython.display import SVG

# output_moreInfo
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# OutputStage

In [None]:
class OutputStage:
    """Draw the stage graph of one family and save it as SVG, PDF and DOT.

    All of the work happens while ``__init__`` runs:

    1. ``__nodeBuilt``: one node per distinct motif label in every stage. The
       node is labelled with the motif label, the number of logs that share
       it and ``len(lb2seg[label])``. Its ``id`` attribute holds those log
       names joined by ``';'``.
    2. ``__nodeConnect``: an edge between nodes of consecutive stages that
       share at least one log (edge label = number of shared logs). Nodes of
       the gap label (the label of segment ``('=',)``), if one exists, are
       relabelled ``gap:...`` with shape='diamond' and fillcolor='yellow'.
    3. ``__setLabel``: one plaintext "stage i / H=<entropy>" node per stage,
       put on the same rank as that stage's nodes and chained to the next
       stage label by an invisible edge.
    4. ``__saveOutput``: the drawings are written under ``outputPath``.
    """
    # Class-level defaults. Every attribute below except gapLb and threshold
    # is overwritten on the instance while __init__ runs; gapLb stays None
    # when the motif has no gap label.
    famStageMatrix = None
    outputPath = None
    famBASE = None
    famSeg2lb = None
    famLb2seg = None

    g = None
    gapLb = None

    stageLen = None
    stage2node_dict = dict()
    stage2Hk_dict = dict()
    stage2common_dict = dict()

    # A stage whose rounded entropy is <= threshold is flagged in stage2common_dict.
    threshold = 0.0

    def __init__(self, stageMatrix, outputPath, BASE, Motif):
        """Build the stage graph and write it to disk.

        Args:
            stageMatrix: dict mapping each log name to its list of per-stage
                segments; ``tuple(segment)`` must be a key of the motif's
                seg->label map.
            outputPath: output directory string; the family name is built from
                its ``split('/')[-3]`` and ``split('/')[-2]`` components.
            BASE: key of ``stageMatrix`` whose row length sets the number of
                stages.
            Motif: object exposing ``getMoti_seg2lb()`` and ``getMoti_lb2seg()``.
        """
        self.famStageMatrix = stageMatrix
        self.outputPath = outputPath
        self.famBASE = BASE
        self.famSeg2lb = Motif.getMoti_seg2lb()
        self.famLb2seg = Motif.getMoti_lb2seg()

        self.g = pgv.AGraph(directed=True)

        self.stageLen = len(self.famStageMatrix[self.famBASE])

        self.g = self.__nodeBuilt(self.g)
        self.g = self.__nodeConnect(self.g)
        self.g = self.__setLabel(self.g)
        self.__saveOutput(self.g)

    # === private functions
    def __nodeBuilt(self, g):
        """Add one node per (stage, motif label) and record the groupings.

        Sets ``self.stage2node_dict`` (0-based stage -> list of node ids) and
        ``self.stage2Hk_dict`` (1-based stage -> {label: [log names]}).
        Node ids are consecutive integers (1, 2, ... on a fresh graph).
        """
        stage2node_dict = {i: [] for i in range(self.stageLen)}
        stage2Hk_dict = dict()

        for i in range(self.stageLen):
            # Group the logs by the motif label of their i-th segment,
            # keeping the order in which labels are first seen.
            n_dict = dict()
            for k in self.famStageMatrix:
                lb_key = self.famSeg2lb[tuple(self.famStageMatrix[k][i])]  # e.g. M1.M2
                n_dict.setdefault(lb_key, []).append(k)

            for lb, logs in n_dict.items():
                node_id = g.number_of_nodes() + 1
                # r"\n" is a literal backslash-n: graphviz's in-label line break.
                g.add_node(node_id,
                           label=lb + ':' + str(len(logs)) + r"\nlen:" + str(len(self.famLb2seg[lb])),
                           id=';'.join(logs))
                stage2node_dict[i].append(node_id)

            stage2Hk_dict[i + 1] = n_dict

        self.stage2node_dict = stage2node_dict
        self.stage2Hk_dict = stage2Hk_dict
        return g

    def __nodeConnect(self, g):
        """Link consecutive-stage nodes that share logs, then restyle gap nodes.

        An edge n_cur -> n_next is added when the two nodes' ';'-joined log
        lists intersect; its label is the size of the intersection. When
        segment ('=',) has a motif label, that label is stored in
        ``self.gapLb`` and its nodes are relabelled via ``_gapLabel``.
        Otherwise a warning is printed.
        """
        for k in range(self.stageLen - 1):
            cur_list = [g.get_node(i) for i in self.stage2node_dict[k]]
            next_list = [g.get_node(i) for i in self.stage2node_dict[k + 1]]
            for n_cur in cur_list:
                cur_logs = set(n_cur.attr['id'].split(';'))
                for n_next in next_list:
                    common_set = cur_logs.intersection(n_next.attr['id'].split(';'))
                    if common_set and not g.has_edge(n_cur, n_next):
                        g.add_edge(n_cur, n_next, label=len(common_set))

        # relabel gap
        if ('=',) in self.famSeg2lb:
            self.gapLb = self.famSeg2lb[('=',)]
            for n in g.nodes():
                if n.attr['label'].split(':')[0] == self.gapLb:
                    n.attr['label'] = self._gapLabel(n.attr['label'])
                    n.attr['shape'] = 'diamond'
                    n.attr['fillcolor'] = 'yellow'
        else:
            print("Warning!")
            print("\tDoes not run others, because only one stage is unnecessary if there is no gap!")

        return g

    @staticmethod
    def _gapLabel(label):
        """Replace the motif label in front of the first ':' with 'gap'.

        Node labels are built as label + ':' + count + backslash-n 'len:' + L,
        so only the part before the first ':' is the motif label. Splitting
        there (rather than on every non-word character) keeps labels such as
        'M1.M2' intact.
        """
        _, _, rest = label.partition(':')
        return 'gap' + ':' + rest

    def __setLabel(self, g):
        """Add a "stage i / H=<entropy>" label node per stage and align them.

        Each label node is placed in a rank='same' subgraph with its stage's
        nodes. Consecutive label nodes (every pair, including the last two)
        are chained with invisible weight-10 edges so they stay in one column.
        Also sets ``self.stage2common_dict``: 0-based stage -> 1 when the
        rounded entropy is <= ``self.threshold``, else 0.
        """
        label_ids = []
        stage2common_dict = {l: 0 for l in self.stage2node_dict}
        for i in self.stage2node_dict:
            label_id = g.number_of_nodes() + 1
            g.add_node(label_id)
            label_node = g.get_node(label_id)
            label_node.attr['shape'] = 'plaintext'
            # Entropy of how the stage's logs are spread over its motif nodes.
            n_list = [len(g.get_node(n).attr['id'].split(';')) for n in self.stage2node_dict[i]
                      if g.get_node(n).attr['id']]
            h = round(self.__entropy(n_list), 2)
            label_node.attr['label'] = 'stage ' + str(i + 1) + '\nH=' + str(h)
            label_node.attr['id'] = 'extra'
            # make them same level
            g.add_subgraph([label_id] + list(self.stage2node_dict[i]), rank='same')
            label_ids.append(label_id)
            if h <= self.threshold:
                stage2common_dict[i] = 1

        # align stage labels: link each label to the next one
        for upper, lower in zip(label_ids, label_ids[1:]):
            g.add_edge(upper, lower, weight=10, style='invis')

        self.stage2common_dict = stage2common_dict
        return g

    def __saveOutput(self, g):
        """Write g as <outputPath>/<family>_output.{svg,pdf,dot} with the dot layout."""
        parts = self.outputPath.split('/')
        family_name = parts[-3] + "_" + parts[-2]
        for fmt in ('svg', 'pdf', 'dot'):
            g.draw(self.outputPath + '/' + family_name + '_output.' + fmt, format=fmt, prog='dot')

    # entropy function
    def __entropy(self, counts):
        """Return the Shannon entropy (natural log) of the distribution in counts.

        Zero counts contribute nothing; an empty input gives 0.
        """
        counts = list(counts)  # iterate once even if a generator is passed
        total = sum(counts)
        e = 0
        for c in counts:
            if c > 0:
                p = float(c) / total
                e += p * math.log(p)
        return abs(e)

    # public functions
    def getStageLen(self):
        """Return the number of stages (length of the BASE row)."""
        return self.stageLen

    def getStage2node(self):
        """Return {0-based stage: [node ids]}."""
        return self.stage2node_dict

    def getStage2Hk(self):
        """Return {1-based stage: {motif label: [log names]}}."""
        return self.stage2Hk_dict

    def getStage2common(self):
        """Return {0-based stage: 1 if entropy <= threshold else 0}."""
        return self.stage2common_dict

    def getStageGap(self):
        """Return the gap motif label, or None when there is none."""
        return self.gapLb

    def getGraph(self):
        """Return the pygraphviz AGraph that was drawn."""
        return self.g
        

# class OutputMotiGraph

In [None]:
class OutputMotiGraph:
    """Draw the matplotlib figures that accompany a family's stage graph.

    ``__init__`` writes three PNG files under ``outputPath``:
    the stage x feature-profile grid, the motif-length CDF and the per-stage
    distinct-motif count. Every figure is closed after it is saved.
    """
    stageMatrix = None
    BASE = None
    featurePro = None
    seg2lb = None
    lb2seg = None
    totalStageLen = None  # --- 2/15
    gaplb = None  # --- 2/15

    stage2node = None
    stageLen = None
    g = None

    label_dict = None  # --- 2/15

    def __init__(self, stageMatrix, BASE, featureProfile, outputStage, outputPath, Motif, label_dict):  # ---2/15
        """Load the inputs and write the three PNG plots under ``outputPath``.

        Args:
            stageMatrix: dict mapping each log name to its per-stage segments.
            BASE: key of ``stageMatrix`` whose row is marked by a dashed line.
            featureProfile: dict keyed like ``stageMatrix``; ``len()`` of each
                value decides the row order of the profile plot.
            outputStage: object exposing ``getStage2node()``,
                ``getStageLen()`` and ``getGraph()`` (e.g. an ``OutputStage``).
            outputPath: output directory; the family name is built from its
                ``split('/')[-3]`` and ``split('/')[-2]`` components.
            Motif: object exposing ``getMoti_seg2lb()``, ``getMoti_lb2seg()``,
                ``getTotalStageLen()`` and ``getMoti_gaplb()``.
            label_dict: dict mapping log names (without a 14-character
                '.trace.hooklog' suffix when the keys carry one) to y labels.
        """
        self.label_dict = label_dict  # --- 2/15

        self.stageMatrix = stageMatrix

        self.BASE = BASE
        self.featurePro = featureProfile
        self.seg2lb = Motif.getMoti_seg2lb()
        self.lb2seg = Motif.getMoti_lb2seg()
        self.totalStageLen = Motif.getTotalStageLen()  # ---2/15
        self.gaplb = Motif.getMoti_gaplb()  # ---2/15

        self.stage2node = outputStage.getStage2node()
        self.stageLen = outputStage.getStageLen()
        self.g = outputStage.getGraph()

        fam_name = outputPath.split('/')[-3] + "_" + outputPath.split('/')[-2]
        stageprofile_path = outputPath + '/' + fam_name + '_Stage- feature profile plot.png'
        motifLenfile_path = outputPath + '/' + fam_name + '_MotifLen- probability CDF graph.png'
        stageDisMoti_path = outputPath + '/' + fam_name + "_Stage- Distinct Motif# graph.png"

        self.__setFeatureProPlot(stageprofile_path)
        self.__setMotiLenProbPlot(motifLenfile_path)
        self.__setDistinctMoti(stageDisMoti_path)

    # === private functions
    @staticmethod
    def _figureSize(wid, hei):
        """Return the (width, height) figsize in inches for a wid x hei grid.

        Wide, flat grids get their height stretched by a factor that grows as
        the hei/wid ratio shrinks; grids at least as tall as wide get (24, 20).
        """
        ratio = hei / wid
        if ratio >= 1:
            return (24, 20)
        for limit, stretch in ((0.07, 12), (0.15, 6), (0.25, 4.5), (0.35, 3), (0.45, 1.5)):
            if ratio < limit:
                return (24, 24 * ratio * stretch)
        return (24, 24 * ratio)

    def _stageWeights(self):
        """Return [dominant label, box color] for every stage.

        The dominant label is the most frequent non-gap motif label of the
        stage; ties go to the label seen first. Its color is 'r' when every
        log has it, 'deepskyblue' for >= 2/3, 'lightgreen' for >= 1/2, and
        'w' otherwise. A stage where every log is a gap gets [None, 'w'].
        """
        counters = [{} for _ in range(self.totalStageLen)]
        for hk in self.stageMatrix:
            motifs = [self.seg2lb[tuple(seg)] for seg in self.stageMatrix[hk]]
            for s in range(self.totalStageLen):
                if motifs[s] == self.gaplb:
                    continue  # skip a gap
                counters[s][motifs[s]] = counters[s].get(motifs[s], 0) + 1

        stageWeight = []
        for counter in counters:
            if not counter:
                stageWeight.append([None, 'w'])
                continue
            top_lb, top_count = max(counter.items(), key=lambda item: item[1])
            weight = float(top_count) / len(self.stageMatrix)
            color = 'w'
            if weight == 1:
                color = 'r'
            elif weight >= float(2) / 3:
                color = 'deepskyblue'
            elif weight >= float(1) / 2:
                color = 'lightgreen'
            stageWeight.append([top_lb, color])
        return stageWeight

    def __setFeatureProPlot(self, stageprofile_path):  # profile every hooklog (DNA)
        """Save a grid with one row per log and one column per stage.

        Rows are ordered by ``len(featurePro[log])`` (largest first) and
        labelled through ``label_dict``. Each non-gap cell is annotated with
        its motif label minus the first character. A stage's dominant label
        is boxed with the color from ``_stageWeights``; other labels are boxed
        white. A dashed blue line marks the BASE row.
        """
        wid = float(self.totalStageLen) + 2
        hei = float(len(self.stageMatrix) + 2)

        fig = plt.figure(num=None, figsize=self._figureSize(wid, hei), dpi=80, facecolor='w', edgecolor='k')
        ax2 = plt.axes([0.0, 0.0, wid / 100, hei / 100])

        if len(self.stageMatrix) < 7:  # --- 2/15 smaller fonts for few rows
            plt.title('Stage - feature profile plot(high)', fontsize=12)
            plt.xlabel('stage index (' + str(self.totalStageLen) + ')', fontsize=8)
            plt.ylabel('feature profile index (' + str(len(self.stageMatrix)) + ')', fontsize=8)
        else:
            plt.title('Stage - feature profile plot(high)')
            plt.xlabel('stage index (' + str(self.totalStageLen) + ')')
            plt.ylabel('feature profile index (' + str(len(self.stageMatrix)) + ')')

        plt.xlim(0, self.totalStageLen + 2)
        plt.ylim(0, len(self.stageMatrix) + 2)

        # Row order, computed once: logs with the largest feature profile first.
        ordered = sorted(self.stageMatrix, key=lambda k: len(self.featurePro[k]), reverse=True)
        if '.trace.hooklog' in list(self.stageMatrix.keys())[0]:
            # --- 04/18: fall back to the name before the first '.' for common hk
            yticks_li = [self.label_dict[hk[:-14]] if hk[:-14] in self.label_dict else hk.split('.')[0]
                         for hk in ordered]
        else:
            yticks_li = [self.label_dict[hk] for hk in ordered]  # --- 2/15

        # Rows sit at y = 1..n, so there is exactly one tick per label.
        plt.yticks(np.arange(1, len(yticks_li) + 1), yticks_li)
        plt.xticks(np.arange(0, int(wid), 5), range(0, int(wid), 5))

        stageWeight = self._stageWeights()

        BASE_index = 0
        for ii, lg in enumerate(ordered):
            if lg == self.BASE:
                BASE_index = ii + 1

        for l in range(self.totalStageLen):  # for each stage---2/15
            top_lb, top_color = stageWeight[l]
            for ii, lg in enumerate(ordered):
                seg = self.stageMatrix[lg][l]
                if seg == ['=']:
                    continue  # gaps are left blank
                lb = self.seg2lb[tuple(seg)]
                ax2.annotate(lb[1:], xy=(l + 1, ii + 1), xycoords="data",
                             size='xx-small', va="center", ha="center",
                             bbox=dict(boxstyle="square", fc=top_color if lb == top_lb else 'w'))

        plt.plot([0, self.totalStageLen + 2], [BASE_index, BASE_index], 'b--')
        plt.savefig(stageprofile_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    def __setMotiLenProbPlot(self, motifLenfile_path):  # Motif length - probability CDF plot
        """Save a 100-bin cumulative, normalised step histogram of segment lengths."""
        fig = plt.figure(num=None, figsize=(12, 4), dpi=80, facecolor='w', edgecolor='k')

        n_bins = 100
        # density=True replaces the normed=1 keyword, which matplotlib
        # deprecated in 2.1 and no longer accepts.
        plt.hist([len(seg) for hk in self.stageMatrix for seg in self.stageMatrix[hk]],
                 n_bins, density=True, histtype='step', cumulative=True)
        plt.ylim(0, 1.05)
        plt.xlim(0, max([len(self.lb2seg[lb]) for lb in self.lb2seg]) + 5)
        plt.title('Motif length - probability CDF graph')
        plt.xlabel('motif length')
        plt.ylabel('probability')
        plt.savefig(motifLenfile_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    def __setDistinctMoti(self, stageDisMoti_path):
        """Save per-stage vertical bars of node counts.

        Red bars show every node of the stage. Blue bars, drawn on top, show
        the nodes whose label contains 'gap'.
        """
        wid = float(self.totalStageLen + 2)

        fig = plt.figure(num=None, figsize=(12, 4), dpi=80, facecolor='w', edgecolor='k')
        plt.title('Stage - distinct motif # graph')
        plt.xlabel('stage index (' + str(self.stageLen) + ')')
        plt.ylabel('# of distinct motif')
        moti_notGap = [len(self.stage2node[k]) for k in sorted(self.stage2node)]
        moti_gap = [len([n for n in self.stage2node[k] if 'gap' in self.g.get_node(n).attr['label']])
                    for k in sorted(self.stage2node)]
        plt.vlines(range(1, self.stageLen + 1), [0], moti_notGap, colors='red', linewidth=3,
                   label='non-gap motif number')
        plt.vlines(range(1, self.stageLen + 1), [0], moti_gap, colors='b', linewidth=3, label='gap number')

        if int(wid) < 10:
            x_interval = 1
        elif int(wid) < 100:
            x_interval = 5
        else:
            x_interval = 100
        plt.xticks(range(0, int(wid), x_interval))

        max_hei = max(moti_notGap)
        if int(max_hei) < 10:
            y_interval = 1
        elif int(max_hei) < 100:
            y_interval = 5
        else:
            y_interval = 100
        plt.yticks(range(0, int(max_hei) + 1, y_interval))

        plt.legend(loc='upper right')
        plt.savefig(stageDisMoti_path, dpi=1000, bbox_inches='tight')
        plt.close(fig)
    