In [30]:
import pandas as pd
import sys
import os
import argparse
from matplotlib import cm
import graph_tool.all as gt
import numpy as np
import matplotlib as mpl
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap
from tqdm import tqdm


class Network(object):
    '''
    Load a Graphviz ``.dot`` graph with graph_tool, attach vertex/edge
    property maps used for styling, compute a Graphviz layout and optionally
    render the result to a file.

    ``__init__`` only stores options; all work happens in :meth:`get`.
    The input graph is expected to provide a ``vertex_name`` vertex property
    and a ``weight`` edge property (the latter is converted with ``float``).
    '''

    # Taxa whose names may carry a subclass suffix after the first '_'
    # (e.g. "Firmicutes_A"); see short_name.
    _SPLIT_TAXA = ('Firmicutes', 'Chloroflex', 'Desulfobacterota', 'Bacilli')
    _SUBCLASS_SUFFIXES = ('A', 'C', 'K', 'I')

    def __init__(self,
                 input_dotfile,
                 cmap=cm.Spectral_r,
                 edgecolmap=None,
                 preflen=None,
                 alpha=.7,
                 minsize=5,
                 edgealpha=.2,
                 exponent=1,
                 edgecollim=2,
                 strongcomponent=True,
                 removeselfloops=True,
                 exponentialscaling=True,
                 outfile=None,
                 minlabeldegree=0):
        '''
        Store drawing options.

        Parameters
        ----------
        input_dotfile : str
            Path handed to ``gt.load_graph``.
        cmap : matplotlib colormap
            Vertex colormap, passed to ``graph_draw`` as ``vcmap=(cmap, alpha)``.
        edgecolmap : matplotlib colormap or None
            Colormap for edge weights; defaults to blue -> grey -> red.
        preflen : int or None
            Number of characters kept from each vertex-name prefix
            (``None`` keeps the whole prefix).
        alpha : float
            Alpha used together with ``cmap`` for vertices.
        minsize : float
            Offset ``A`` added by :meth:`f` to vertex and font sizes.
        edgealpha : float
            Alpha channel written into every edge colour.
        exponent : float
            Base used by :meth:`f` for vertex sizes when
            ``exponentialscaling`` is true.
        edgecollim : float
            Edge weights are normalised over ``[-edgecollim, edgecollim]``
            before being mapped through ``edgecolmap``.
        strongcomponent : bool
            Keep only the largest component (``directed=True``).
        removeselfloops : bool
            Remove self-loops after loading.
        exponentialscaling : bool
            Passed to :meth:`f` as ``E``.
        outfile : str or None
            If given, :meth:`get` renders the graph to this file.
        minlabeldegree : int
            Vertices with out-degree below this get an empty drawn label.
            The default of 0 labels every vertex.
        '''
        if edgecolmap is None:
            self.edgecolmap = LinearSegmentedColormap.from_list(
                'edgecolmap',
                ['#0000ff',
                 '#888888',
                 '#ff0000'])
        else:
            self.edgecolmap = edgecolmap
        self.minsize = minsize
        self.cmap = cmap
        self.network = None
        self.dotfile = input_dotfile
        self.preflen = preflen
        self.edgealpha = edgealpha
        # NOTE(review): with the defaults exponent=1 and exponentialscaling=True,
        # f() gives 1**k + minsize for every degree k, so all vertices are
        # drawn at the same size -- confirm this is intended.
        self.exponent = exponent
        self.edgecollim = edgecollim
        self.strongcomponent = strongcomponent
        self.removeselfloops = removeselfloops
        self.exponentialscaling = exponentialscaling
        self.alpha = alpha
        self.minlabeldegree = minlabeldegree
        # Property maps / layout, all populated by get().
        self.nm = None
        self.od = None
        self.pos = None
        self.hl = None
        self.ew = None
        self.ew_pen = None
        self.e_marker = None
        self.deg = None
        self.control = None
        self.ods = None
        self.varclass = None
        self.ecol = None
        self.outfile = outfile

    def short_name(self, s, LEN):
        '''
        Return a shortened class label for vertex name ``s``.

        The label is the part of ``s`` before the first ``'_'``, truncated
        to ``LEN`` characters (``None`` keeps it whole).

        If the first four characters of that prefix occur inside one of
        ``_SPLIT_TAXA``, and the next ``'_'``-separated field is one of
        ``A``/``C``/``K``/``I``, that suffix is kept. For example,
        ``'Firmicutes_A_x'`` -> ``'Firmicutes_A'``.
        '''
        parts = s.split('_')
        label = parts[0][:LEN]
        key = parts[0][:4]
        if (len(parts) > 1
                and parts[1] in self._SUBCLASS_SUFFIXES
                and any(key in taxon for taxon in self._SPLIT_TAXA)):
            return label + '_' + parts[1]
        return label

    def f(self, x, A=0, E=True, exponent=2.0):
        '''
        Size transform for a vertex degree ``x``.

        Returns ``exponent**x + A`` when ``E`` is true, otherwise ``x + A``.
        '''
        if E:
            return exponent**x + A
        return x + A

    def sfunc(self, val, SIGN=False):
        '''Return ``np.sign(val)`` when ``SIGN`` is true, else ``val`` unchanged.'''
        if SIGN:
            return np.sign(val)
        return val

    def get(self):
        '''
        Build the decorated graph, compute its layout and optionally draw it.

        Populates ``self.network`` and the property-map attributes
        (``nm``, ``hl``, ``varclass``, ``od``, ``ods``, ``deg``, ``ew``,
        ``ew_pen``, ``e_marker``, ``ecol``, ``pos``, ``control``).
        When ``self.outfile`` is set, the graph is rendered there with
        ``gt.graph_draw``.
        '''
        self._load_graph()
        self._label_vertices()
        self._size_vertices()
        self._style_edges()
        self._layout()
        if self.outfile is not None:
            self._draw()

    def _load_graph(self):
        '''Load the .dot file, then apply component and self-loop filtering.'''
        self.network = gt.load_graph(self.dotfile)
        if self.strongcomponent:
            self.network = gt.extract_largest_component(
                self.network, directed=True, prune=True)
        if self.removeselfloops:
            gt.remove_self_loops(self.network)

    def _label_vertices(self):
        '''Create drawn labels, halo flags and a numeric class per vertex.'''
        g = self.network
        self.nm = g.new_vertex_property("string")
        full = g.new_vertex_property("string")
        self.hl = g.new_vertex_property("bool")
        # "text" holds the short name of every vertex; self.nm (the drawn
        # label) may be blanked for low-degree vertices.
        g.vertex_properties["text"] = full
        classes = set()
        for v in g.vertices():
            label = self.short_name(g.vp.vertex_name[v], self.preflen)
            full[v] = label
            self.nm[v] = label if v.out_degree() >= self.minlabeldegree else ''
            self.hl[v] = False  # halos are currently disabled for all vertices
            classes.add(label)

        # Sorted so that class indices -- and hence vertex colours -- are
        # identical across runs; set iteration order of str is not.
        class_index = {name: float(i) for i, name in enumerate(sorted(classes))}
        self.varclass = g.new_vertex_property("float")
        g.vertex_properties["varclass"] = self.varclass
        for v in g.vertices():
            self.varclass[v] = class_index[full[v]]

    def _size_vertices(self):
        '''Compute vertex sizes, label font sizes and the out-degree drawing order.'''
        g = self.network
        self.od = g.new_vertex_property("float")
        g.vertex_properties["size"] = self.od
        self.ods = g.new_vertex_property("float")
        # Own key so the font sizes do not clobber the vertex sizes in "size".
        g.vertex_properties["font_size"] = self.ods
        for v in g.vertices():
            k = v.out_degree()
            self.od[v] = self.f(k,
                                A=self.minsize,
                                E=self.exponentialscaling,
                                exponent=self.exponent) + 5
            # Base 1: with exponentialscaling this is a constant minsize + 3.
            self.ods[v] = self.f(k,
                                 A=self.minsize,
                                 E=self.exponentialscaling,
                                 exponent=1) + 2
        self.deg = g.degree_property_map("out")

    def _style_edges(self):
        '''Derive weight, pen width, end marker and colour for every edge.'''
        g = self.network
        self.ew = g.new_edge_property("float")
        g.edge_properties["eweight"] = self.ew
        self.ew_pen = g.new_edge_property("float")
        g.edge_properties["eweight_pen"] = self.ew_pen
        self.e_marker = g.new_edge_property("string")
        g.edge_properties["e_marker"] = self.e_marker
        self.ecol = g.new_edge_property("vector<double>")
        g.edge_properties["ecol"] = self.ecol

        # One mapper for all edges (loop-invariant): weights normalised over
        # [-edgecollim, edgecollim] and mapped through edgecolmap.
        mapper = cm.ScalarMappable(
            mpl.colors.Normalize(vmin=-self.edgecollim, vmax=self.edgecollim),
            cmap=self.edgecolmap)
        for e in g.edges():
            w = float(g.ep.weight[e])
            self.ew[e] = w
            # Logistic of |w| + 0.05 scaled by 4, so widths stay below 4.
            self.ew_pen[e] = 4 / (1 + np.exp(-.05 - np.fabs(w)))
            # Negative weights end in a bar, non-negative ones in an arrow.
            self.e_marker[e] = 'bar' if w < 0 else 'arrow'
            rgba = mapper.to_rgba(w)
            self.ecol[e] = (*rgba[:3], self.edgealpha)

    def _layout(self):
        '''Compute vertex positions with Graphviz and curved-edge control points.'''
        self.pos = gt.graphviz_draw(self.network,
                                    overlap=False,
                                    vsize=20,
                                    sep=20,
                                    output=None)
        self.control = self.network.new_edge_property("vector<double>")
        for e in self.network.edges():
            half = np.sqrt(np.sum((self.pos[e.source()].a
                                   - self.pos[e.target()].a) ** 2)) / 2
            # Four (x, y) pairs; the y offsets scale with half the edge
            # length. Presumably these are in graph_draw's edge-relative
            # frame -- see its edge_control_points docs.
            self.control[e] = [0.0, 0.0, 0.0, .2 * half, 0.5, half, 1.0, 0.0]

    def _draw(self):
        '''Render the decorated graph to ``self.outfile``.'''
        gt.graph_draw(self.network, nodesfirst=False,
                      pos=self.pos,
                      vertex_halo=self.hl,
                      vertex_halo_color=[.2, .2, .2, .1],
                      edge_pen_width=self.ew_pen,
                      edge_end_marker=self.e_marker,
                      vorder=self.deg,
                      edge_marker_size=10,
                      vertex_color=self.varclass,
                      edge_color=self.ecol,
                      vertex_pen_width=1.5,
                      vertex_size=self.od,
                      vertex_text=self.nm,
                      vcmap=(self.cmap, self.alpha),
                      edge_control_points=self.control,
                      vertex_fill_color=self.varclass,
                      vertex_font_size=self.ods,
                      vertex_text_color=[.1, .1, .1, .8],
                      output=self.outfile)



In [31]:
# Build the plotting helper; calling get() later renders the drawing to tmp.pdf.
network = Network(
    './hypothesis_phylum_31_33_-1.dot',
    outfile='tmp.pdf',
)

In [32]:
# Load, decorate and lay out the graph; because outfile='tmp.pdf' was given,
# this also renders the drawing to that file.
network.get()