In [278]:
import os.path as osp
from copy import copy
from glob import glob

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import patsy
import seaborn as sns
from matplotlib.colors import ListedColormap, LogNorm
from networkx.algorithms.shortest_paths.generic import shortest_path_length
from networkx.readwrite.graphml import read_graphml

## global figure style: SVG text kept as text (svg.fonttype='none'), no LaTeX rendering
plt.rcParams['svg.fonttype']='none'
plt.rcParams['text.usetex']=False
plt.rcParams['font.family']='sans'

## Gather results for different scenarios

In [17]:
## a more streamlined implementation of the function in the attractors script
def attractors_1state_from_txt(txt_name, network_name, cleanup=True):
    """Read one representative state per attractor from a text listing of attractors.

    Parameters
    ----------
    txt_name : str
        Attractor listing. The first line is skipped. Each attractor is a run of
        0/1 state strings (in the node order of the ``.cnet`` header) that is
        closed by a line starting with ``Attractor``; the last whitespace token
        of that line is parsed as the attractor size.
    network_name : str
        Path prefix of the network files. The first comma-separated field of each
        non-blank line of ``<network_name>.bnet`` gives the output column order;
        the first line of ``<network_name>.cnet`` (``# a, b, ...``) gives the node
        order used in the state strings.
    cleanup : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    pandas.DataFrame
        One row per attractor (index 1..n) holding the FIRST state listed for
        that attractor, with columns in ``.bnet`` node order. Empty (0 rows) if
        the listing contains no attractors.
    """
    with open(network_name+'.bnet',encoding='utf-8') as fo:
        # skip blank lines (e.g. a trailing newline) so they do not become '' node names
        nodes_bnet = [ln.strip().split(',')[0] for ln in fo if ln.strip()]
    with open(network_name+'.cnet',encoding='utf-8') as fo:
        ln0 = fo.readline()
    nodes_cnet = ln0.replace('\n', "").replace('#', "").strip().split(', ')
    # position of each .bnet node inside the .cnet-ordered state strings
    idx = pd.Series({name: pos for pos, name in enumerate(nodes_cnet)}).loc[nodes_bnet]
    with open(txt_name,encoding='utf-8') as fo:
        lines = fo.readlines()
    attractors = []
    sizes = []  # parsed for completeness; not returned
    current_attractor = []
    for line in lines[1:]:
        cleanline = line.strip()
        if cleanline.startswith('Attractor'):
            # this line closes an attractor: keep the first state listed for it
            attractors.append(current_attractor[0])
            current_attractor = []
            # whole last token, so sizes >= 10 are not truncated to one digit
            sizes.append(cleanline.split()[-1])
        elif 'average' in cleanline:
            pass
        elif cleanline.startswith(('0','1')):
            current_attractor.append(np.array([int(elt) for elt in cleanline]))
    if not attractors:
        return pd.DataFrame(columns=idx.index, dtype=int)
    attractors = np.vstack(attractors).astype(int)[:,idx.values]
    return pd.DataFrame(attractors,columns=idx.index,index=range(1,attractors.shape[0]+1))


In [None]:
import os.path as osp  # used for the file-existence check below

## this list of genes is the set that are potentially irreversible
## (defined before the aggregation loop, which uses it to rank KO-vs-OE differences)
selg=['crp','phoB','hns','cra','fis','leuO','fnr',
     'rhaS','gadE','rhaR','bglJ','galS','lrp',
     'fur','rcsB','galR','stpA','arcA','gadX','fliZ',
     'cspA','uxuR','gadW','ydeO','xylR','exuR','mlrA',
     'fhlA','oxyR','soxS','rcsA','evgA','srlR','marR',
     'rob','narL','adiY','dcuR','soxR','ptsG','csgD',
     'mazE','flhC','mlc','flhD','mazF','ompR','gutM',
     'yjjQ','marA','hdfR']

## Load per-realization irreversibility and attractor files.
## Keys: irrev_d[(srt, pt, ss, rr, rep)]; avgstate_d / numattr_d[(srt, ss, rr, rep)]
irrev_d = {}
avgstate_d = {}
numattr_d = {}
ko_check = None  # (results, attractors_df) of the last knockout file, used by the sanity check below
for pt in ['KO','OE']:
    for srt in ['desc','asc']:
        fn_l = glob("results/result-%s-twoparam_*_*_%s_*.csv" % (pt,srt))
        for fn in fn_l:
            # file name: result-<pt>-twoparam_<ss>_<rr>_<srt>_<rep>_<...>
            __,ss,rr,__,rep,__ = osp.basename(fn).split('_')
            att_fn = "attfiles/1st_twoparam_%s_%s_%s_%s_%d.csv" % (ss,rr,srt,rep,int(rep))
            if not osp.exists(att_fn):
                print('missing',att_fn)
                continue
            attractors_df = pd.read_csv(att_fn,header=None)
            # NOTE(review): asc_l / desc_l (node orders per input sorting) must be defined before this cell
            attractors_df.columns = asc_l if srt == 'asc' else desc_l
            attractors_df.index = list(range(1,attractors_df.shape[0]+1))
            results = pd.read_csv("results/result-%s-twoparam_%s_%s_%s_%s_%d-pre.csv" % (pt,ss,rr,srt,rep,int(rep)),index_col=0)
            results = results.iloc[:-2]  # drop the last two rows (presumably summary rows -- TODO confirm)
            if attractors_df.shape[0]==results.shape[0]:
                results = results.loc[attractors_df.index]
            irrev_d[(srt,pt,ss,rr,rep)] = results.mean(axis=0) ## the mean is taken across attractors
            if pt=='KO': ## this does not depend on the perturbation type
                avgstate_d[(srt,ss,rr,rep)] = attractors_df.loc[results.index].mean(axis=0)
                numattr_d[(srt,ss,rr,rep)] = attractors_df.shape[0]
                ko_check = (results, attractors_df)

DF = pd.DataFrame(list(irrev_d))

## sanity check that the attractors match the irreversibility table:
##     for knockouts, everywhere that there is a "0" in the attractor should have a NaN entry in the irreversibility
if ko_check is None:
    print("Test skipped: no knockout results loaded")
else:
    ko_results, ko_attractors_df = ko_check
    for nn,row in ko_results.loc[:257,ko_attractors_df.columns].iterrows():
        if nn>256:
            continue
        inds1 = np.where(row.isna())[0]
        inds2 = np.where(ko_attractors_df.loc[nn]==0)[0]
        # array_equal also rejects index lists of different length
        if not np.array_equal(inds1,inds2):
            print("Test failed at attractor", nn)
            break
    else:
        print("Test passed")

## (ss, rr) parameter pairs present in the data
ppairs_l = [(kk[2],kk[3]) for kk in irrev_d]

ens_avg_irrev = {}
ens_avg_state = {}
ens_avg_nattr = {}
for (ss,rr) in sorted(set(ppairs_l)):
    print(ss,rr)
    for srt in ['asc','desc']:
        print(srt)
        sel_keys = [kk for kk in irrev_d if kk[0]==srt and kk[1]=='KO' and kk[2]==ss and kk[3]==rr]
        rep_irrev_d = {}
        rep_state_d = {}
        rep_numattr_d = {}
        rep_nz_d = {}
        test_d = {}
        for sk in sel_keys:
            rep = sk[-1]
            okey = (srt,'OE',ss,rr,rep)
            avg_key = (srt,ss,rr,rep)
            state = avgstate_d[avg_key]  # mean attractor value per gene (fraction ON if states are 0/1)
            num = numattr_d[avg_key]
            ptirrev_df = pd.concat({'KO':irrev_d[sk],'OE':irrev_d[okey]},axis=1)
            # KO irreversibility weighted by `state`, OE irreversibility by `1 - state`
            avg_irrev = pd.concat({'KO':(ptirrev_df.KO*state),'OE':(1-state)*ptirrev_df.OE},axis=1).sum(axis=1)
            rep_irrev_d[int(rep)]=avg_irrev
            rep_state_d[int(rep)]=state
            rep_nz_d[int(rep)]=avg_irrev.index[np.where(avg_irrev>0)[0]].unique().tolist()
            rep_numattr_d[int(rep)] = num
            KOvOE = np.abs(irrev_d[sk]-irrev_d[okey]).loc[selg].sort_values()
            test_d[rep]=KOvOE.index[-10:].tolist()
        irrev_df=pd.DataFrame(rep_irrev_d)
        ens_avg_irrev[(ss,rr,srt)] = irrev_df.mean(axis=1)
        state_df = pd.DataFrame(rep_state_d)
        ens_avg_state[(ss,rr,srt)] = state_df.mean(axis=1)
        numattr_ser = pd.Series(rep_numattr_d)
        ens_avg_nattr[(ss,rr,srt)] = numattr_ser
        nz_ser = pd.Series(rep_nz_d)

## genes of selg ranked by irreversibility averaged over all parameter sets and sortings
INDS = pd.DataFrame(ens_avg_irrev).mean(axis=1).loc[selg].sort_values(ascending=False).index

## GraphML file for Fig. 4

### Number of genes reached in each origon

In [None]:
fn = 'input_files/generegulation_tmp.txt'
# NOTE(review): assumes exactly 11 header lines before the data rows
tab = pd.read_table(fn,skiprows=11,header=None)
# column names come from the leading '#' lines that contain ') ' (text after ') ')
cols = []
with open(fn,'r') as fh:
    for ln in fh:
        ln = ln.strip()
        if not ln.startswith('#'):
            break
        if ') ' in ln:
            cols.append(ln.split(') ')[1])
tab.columns = cols

df = tab
df_sc = df.loc[:,['GENE_NAME_REGULATOR','GENE_NAME_REGULATED','GENEREGULATION_FUNCTION']]
sdef_edges = df_sc[df_sc.GENEREGULATION_FUNCTION.isin(['activator','repressor'])]
G = nx.DiGraph()
# signed edges: +1 activator, -1 repressor (a later row for the same pair overwrites the weight)
G.add_weighted_edges_from(
    (a, b, 1 if c=='activator' else -1)
    for a, b, c in sdef_edges.itertuples(index=False, name=None))
#%%
#sizes of strongly connected components
sccs = sorted(nx.strongly_connected_components(G),key=len, reverse=True)
wccs = sorted(nx.weakly_connected_components(G),key=len, reverse=True)
len_sccs = [len(c) for c in sccs]
len_wccs = [len(c) for c in wccs]

#%%
## origon of each node: every node reachable from it, including itself
nums_reached = {}
nodes_reached = {}
for uu in G.nodes:
    # one traversal per source instead of one has_path search per (source, target) pair
    downstream = nx.descendants(G, uu)
    reached = [vv for vv in G.nodes if vv == uu or vv in downstream]
    nums_reached[uu]=len(reached)
    nodes_reached[uu]=reached
nums_reached_ser = pd.Series(nums_reached)
nums_reached_ser.sort_values()

### Pruning to the core network

In [None]:
def simplify_network(G):
    """Reduces network G by pruning all nodes of zero out-degree.

    Returns the subgraph of G induced by the nodes whose out-degree is non-zero
    and prints how many nodes were kept. A single call removes only the current
    sinks; nodes that become sinks as a result need another call.
    """
    simplified_nodes = [node for node in G.nodes if G.out_degree(node) != 0]
    G_simplified = G.subgraph(simplified_nodes)
    print(len(simplified_nodes))
    return G_simplified

## Start from the node whose origon is largest (phoB) and, separately, from the largest
## weakly connected component. Prune each by repeatedly removing sinks until a pass
## removes nothing. Only the pruned origon is written (networks/rs2.gml); afterwards
## `net` remains bound to the pruned WCC.
nd_reached_most = nums_reached_ser.idxmax() ## phoB
reached_most = nodes_reached[nd_reached_most] + [nd_reached_most]
G_wcc = G.subgraph(list(wccs[0])) ## largest weakly connected component
G_reach = G.subgraph(reached_most) ## largest origon
for ii,net in enumerate([G_reach,G_wcc]):
    n_before = len(net.nodes)
    net = simplify_network(net)
    while len(net.nodes) < n_before:
        n_before = len(net.nodes)
        net = simplify_network(net)
    if ii == 0:
        nx.write_gml(net,'networks/rs2.gml')

### Generation of Fig. 4

In [None]:
## Colour every node of the pruned network by its ensemble-averaged irreversibility
## (log scale 1e-3..1); nodes below 1e-3 are grey.
gnavg_irrev = pd.DataFrame(ens_avg_irrev).mean(axis=1)
lnorm=LogNorm(vmin=1e-3,vmax=1e0)
# colormap (also used by later figures); defined here so this cell runs on a fresh kernel
my_cmap = copy(mpl.colormaps['Spectral_r'])
my_cmap.set_bad('C7')
my_cmap.set_under('w')
G_rs2 = nx.read_gml('./networks/rs2.gml')
nodes_rs2 = list(G_rs2.nodes())

# irreversibility per node; network nodes without an estimate get 0
nodes_all = list(gnavg_irrev.index)
irr_all = list(gnavg_irrev.values)
known = set(nodes_all)
for nd in nodes_rs2:
    if nd not in known:
        nodes_all.append(nd)
        irr_all.append(0)
overall = dict(zip(nodes_all,irr_all))
nodecolor_d = {}
for nd in G_rs2.nodes():
    # look up in `overall` so nodes missing from gnavg_irrev use 0 instead of raising KeyError
    iprob = overall[nd]
    if iprob<1e-3:
        nodecolor_d[nd]='#777777ff'
    else:
        # '#rrggbbaa' with every channel zero-padded to two hex digits
        nodecolor_d[nd]=mpl.colors.to_hex(my_cmap(lnorm(iprob)),keep_alpha=True)

nx.set_node_attributes(G_rs2, nodecolor_d, 'nodecolor')
nx.set_node_attributes(G_rs2, dict(zip(nodes_all,nodes_all)), 'name')
nx.write_gml(G_rs2, './networks/rs2_irr_neg2_newx.gml')
nx.write_graphml(G_rs2, './networks/rs2_irr_neg2_newx.graphml')

### Lower bound of the size of the set of possible rules

In [None]:
## Lower bound on the number of rule assignments for `net` (the pruned WCC from the pruning cell):
## a node with k inputs contributes a factor 2**(k-1); input-free nodes contribute 1.
tot = 1
for nd in net.nodes():
    kplus = net.in_degree(nd)
    # clamp at 0 so input-free nodes do not contribute a fractional 0.5 (which made tot a float)
    tot *= 2**max(kplus-1, 0)
print(tot)

In [None]:
import math
## order of magnitude of the bound; math.log10 accepts arbitrarily large ints (float(tot) overflows past ~1.8e308)
math.log10(tot)

## Fig. 5

In [None]:
## genes (selg order) x input sorting ('asc'/'desc'); transpose instead of the deprecated groupby(axis=1)
irrev_by_insrt = pd.DataFrame(ens_avg_irrev).T.groupby(level=2).mean().T.loc[selg]

def path_weight(G,path):
    """Product of 1/in_degree over every node of `path` except the first.

    Each hop is down-weighted by the number of regulators of the node it enters.
    """
    weight = 1
    for idx in range(1, len(path)):
        weight *= 1/G.in_degree(path[idx])
    return weight

def path_sign(G,path):
    """Product of the 'weight' edge attributes along consecutive edges of `path`.

    With +1/-1 edge weights this is the overall sign of the path (1 for a single node).
    """
    sign = 1
    for src, dst in zip(path[:-1], path[1:]):
        sign *= G.get_edge_data(src,dst)['weight']
    return sign
    
def num_scc_weighted(G,group,cutoff=None):
    """For each node in `group`, sum 1/distance over the SCCs it reaches.

    For every strongly connected component of G (except singleton SCCs whose
    first listed out-edge has weight -1), the first member -- in set-iteration
    order -- reachable from the node within `cutoff` steps contributes
    1/shortest-path-length, and the search moves on to the next SCC.

    Parameters
    ----------
    G : networkx.DiGraph with a 'weight' edge attribute
    group : iterable of source nodes
    cutoff : int, optional
        Maximum counted path length; defaults to the notebook-level ``max_length``.

    Returns
    -------
    list of float, one entry per node of `group` (same order).
    """
    if cutoff is None:
        cutoff = max_length
    sccs = list(nx.strongly_connected_components(G))
    num_edge_scc = []
    for n in group:
        num_edge = 0
        for s in sccs:
            if len(s) == 1:
                (only,) = s
                out_edges = list(G.out_edges(only, data=True))
                # skip singletons whose first out-edge is repressive (only that edge is inspected)
                if out_edges and out_edges[0][2]['weight'] == -1:
                    continue
            if s == {n}:
                continue
            for n_s in s:
                if n_s == n:
                    continue
                try:
                    length = nx.shortest_path_length(G,source=n,target=n_s)
                except nx.NetworkXNoPath:
                    continue
                # n != n_s, so length >= 1
                if length <= cutoff:
                    num_edge += 1/length
                    break
        num_edge_scc.append(num_edge)
    return num_edge_scc

def num_path_scc_weighted(G,group,cutoff=None):
    """For each node in `group`, sum weighted short simple paths into SCCs.

    For every strongly connected component of G (except singleton SCCs whose
    first listed out-edge has weight -1) and every member other than the source,
    all simple paths of at most `cutoff` edges are enumerated; each contributes
    ``abs(path_weight * path_sign)``, i.e. the product of 1/in-degree along it.

    Parameters
    ----------
    G : networkx.DiGraph with a 'weight' edge attribute
    group : iterable of source nodes
    cutoff : int, optional
        Maximum path length in edges; defaults to the notebook-level ``max_length``.

    Returns
    -------
    list of float, one entry per node of `group` (same order).
    """
    if cutoff is None:
        cutoff = max_length
    sccs = list(nx.strongly_connected_components(G))
    num_edge_scc = []
    for n in group:
        num_edge = 0
        for s in sccs:
            if len(s) == 1:
                (only,) = s
                out_edges = list(G.out_edges(only, data=True))
                # skip singletons whose first out-edge is repressive (only that edge is inspected)
                if out_edges and out_edges[0][2]['weight'] == -1:
                    continue
            if s == {n}:
                continue
            for n_s in s:
                if n_s == n:
                    continue
                # all_simple_paths yields nothing (it does not raise) when no path exists
                for path in nx.all_simple_paths(G, n, n_s, cutoff=cutoff):
                    num_edge += abs(path_weight(G,path)*path_sign(G,path))
        num_edge_scc.append(abs(num_edge))
    return num_edge_scc


## Fig. 5: irreversibility vs. weighted number of short paths into SCCs
max_length = 4  # path-length cutoff read (as a global) by num_path_scc_weighted / num_scc_weighted

# one value per gene, in the order of irrev_by_insrt.index
path_scc_pos = num_path_scc_weighted(G_rs2,irrev_by_insrt.index)

fig,ax_arr = plt.subplots(1,2,figsize=(3.375,2),sharey=True,sharex=True)
# one column of irrev_by_insrt per input sorting: 'asc' -> left panel, otherwise right panel
for nm,col in irrev_by_insrt.items():
    ax = ax_arr[0] if nm=='asc' else ax_arr[1]
    title = 'Ascending' if nm=='asc' else 'Descending'
    # uses the global colour norm `lnorm` and colormap `my_cmap` (must already be defined)
    ax.scatter(path_scc_pos,col, c=col.values, norm=lnorm, cmap=my_cmap, alpha=0.9, s=16)
    # power-law fit log(y) = mm*log(x) + bb over genes with positive irreversibility
    mm,bb = np.polyfit(np.log(np.asarray(path_scc_pos)[col>0]),np.log(col[col>0]),deg=1)
    # NOTE(review): assumes every path count is > 0 (log10 of 0 would give -inf)
    xvals = np.logspace(np.log10(np.amin(path_scc_pos)),np.log10(np.amax(path_scc_pos)))
    yvals = xvals**mm * np.exp(bb)
    # NOTE(review): this is 1 - ||fit - y|| / ||y|| (a relative L2 error score), not the
    # coefficient of determination, although it is labelled R^2 below
    rr22 = 1 - np.linalg.norm((np.asarray(path_scc_pos)[col>0])**mm*np.exp(bb)-col[col>0])/np.linalg.norm(col[col>0])
    print(rr22)
    ax.plot(xvals,yvals,ls=':',color='C7',alpha=0.5)
    ax.set_xlabel('Weighted no. of paths to SCCs',fontsize=8)
    ax.set_title(title,fontsize=8)
    plt.setp(ax.get_xticklabels(),size=6)
    ax.set_yscale('log')
    ax.set_xscale('log')
    # annotation placed at fixed data coordinates (3, 4e-3)
    ax.annotate('$R^2$=%.02f' % rr22, (3,4e-3),fontsize=6,fontstyle='italic')
    # label the first 11 genes of irrev_by_insrt.index (selg order, not ranked by irreversibility)
    for ii,txt in enumerate(irrev_by_insrt.index[:11]):
         ax.annotate(txt,(path_scc_pos[ii],col.iloc[ii]),fontsize=6,fontstyle='italic')

plt.setp(ax_arr[0].get_yticklabels(),size=6)
ax_arr[0].set_ylabel('Irreversibility',fontsize=8)
fig.savefig('figs/new_fig5.svg')


### Significance of the trends

In [None]:
from scipy.stats import linregress
## log-log regression for the ascending sorting (genes with positive irreversibility only)
_asc_pos = irrev_by_insrt.asc > 0
linregress(np.log(np.asarray(path_scc_pos)[_asc_pos]), np.log(irrev_by_insrt.asc[_asc_pos]))

In [None]:
## same log-log regression for the descending sorting
_desc_pos = irrev_by_insrt.desc > 0
linregress(np.log(np.asarray(path_scc_pos)[_desc_pos]), np.log(irrev_by_insrt.desc[_desc_pos]))

## Fig. 6

In [None]:
## Colormap for the heatmaps: a copy of Spectral_r whose "bad" (invalid/masked)
## colour is grey C7 and whose "under" colour (below vmin) is white.
my_cmap = mpl.colormaps['Spectral_r'].with_extremes(bad='C7', under='w')

MIN_OFF = 1e-6  # offset added before log normalisation so exact zeros fall below vmin
def irr_imshow(ax,data,inds,plot_xlabels=False,plot_ylabels=False,xlbl='',ylbl='',xtls=(),norm=None,cmap=None):
    """Draw a genes x parameter-values irreversibility heatmap on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes
    data : pandas.DataFrame indexed by gene; only rows `inds` are drawn
    inds : row labels (genes) to draw, top to bottom
    plot_xlabels, plot_ylabels : bool
        Whether to draw tick labels and the axis label on each axis.
    xlbl, ylbl : str
        Axis labels.
    xtls : sequence of str
        x tick labels, one per column of `data`.
    norm, cmap : optional
        Colour normalisation and colormap; default to the notebook-level
        `lnorm` and `my_cmap` at call time.

    Returns
    -------
    (AxesImage, Axes)
    """
    if norm is None:
        norm = lnorm
    if cmap is None:
        cmap = my_cmap
    shown = data.loc[inds]
    # MIN_OFF moves exact zeros just above 0 so the log norm maps them below vmin instead of masking them
    im = ax.imshow(shown+MIN_OFF,norm=norm,cmap=cmap)
    ax.set_yticks(np.arange(len(inds)))
    # one tick per column rather than a hard-coded 6
    ax.set_xticks(np.arange(shown.shape[1]))
    if plot_ylabels:
        ax.set_yticklabels(['%s' % ind for ind in inds],size=5,fontstyle='italic')
        ax.set_ylabel(ylbl,size=6)
    if plot_xlabels:
        ax.set_xticklabels(xtls,size=5,rotation=90,horizontalalignment='center')
        ax.set_xlabel(xlbl,size=6)
    return im,ax
    
## Parameter cuts. DF rows are irrev_d keys (srt, pt, <param col 2>, <param col 3>, rep), all strings.
##   CUT1: column 2 == '0.00'                                  (x-axis 'r, s=0')
##   CUT2: column 3 == '1.00', plus the ('0.00', '0.00') point  (x-axis 's, r=1')
##   CUT3: column 2 + column 3 == 1, plus ('0.00', '0.00')      (x-axis 's, r=1-s')
CUT1 = DF[DF.loc[:,2]=='0.00']
CUT2 = DF[(DF.loc[:,3]=='1.00')|((DF.loc[:,3]=='0.00') & (DF.loc[:,2]=='0.00'))]
CUT3 = DF[(DF.loc[:,2].astype(float)+DF.loc[:,3].astype(float)==1)|((DF.loc[:,3]=='0.00')&(DF.loc[:,2]=='0.00'))]

## Two figures: the first 25 genes of INDS (colour scale from 1e-2) and the rest (from 1e-3).
inds_l = np.split(INDS.get_level_values(0),[25])
for iii,inds in enumerate(inds_l):
    lnorm=LogNorm(vmin=1e-2,vmax=1e0) if iii==0 else LogNorm(vmin=1e-3,vmax=1e0)
    fig,ax_arr = plt.subplots(4,3,figsize=(2.25,10),sharex=True,sharey=True)
    for jj,CUT in enumerate([CUT1,CUT2,CUT3]):
        repavg_l = []
        for _grp_key,grp in CUT.groupby([0,1,2,3]):
            sel_d = {}
            for _row_idx,row in grp.iterrows():
                sel_key = tuple(row.values)
                srt_k, pt_k, p2, p3, rep_k = sel_key
                if jj<1:
                    x_val = p3
                elif p3=='0.00':
                    # in CUT2/CUT3, rows with column 3 == '0.00' are placed at x = '1.00'
                    x_val = '1.00'
                else:
                    x_val = p2
                sel_d[(srt_k,pt_k,x_val,int(rep_k))]=irrev_d[sel_key]
            # average over replicates -> columns (srt, pt, x value); transpose instead of groupby(axis=1)
            repavg_irrev = pd.DataFrame(sel_d).T.groupby(level=[0,1,2]).mean().T
            repavg_l.append(repavg_irrev)
        cut_data = pd.concat(repavg_l,axis=1).loc[selg]

        for kk,srt in enumerate(['asc','desc']):
            for ll,pt in enumerate(['KO','OE']):
                ROW = kk*2 +ll
                ax = ax_arr[ROW,jj]
                plot_data = cut_data.xs((srt,pt),level=[0,1],axis=1).loc[inds]
                # CUT1 columns run from high to low value, CUT2/CUT3 from low to high
                srt_cols = sorted(plot_data.columns,key=float,reverse=(jj<1))
                if jj==0:
                    pylbl = True
                    srtlbl = 'Asc.' if srt=='asc' else 'Desc.'
                    ylbl = 'Gene, Pert: %s, Sort: %s ' % (pt,srtlbl)
                else:
                    pylbl = False
                    ylbl = ''
                if ROW==3:
                    pxlbl = True
                    xtls = srt_cols
                    xlbl = ('r, s=0', 's, r=1', 's, r=1-s')[jj]
                else:
                    pxlbl = False
                    xlbl = ''
                    xtls = []  # must be bound even when tick labels are not drawn
                img,ax = irr_imshow(ax,plot_data.loc[:,srt_cols],inds,plot_xlabels=pxlbl,plot_ylabels=pylbl,xlbl=xlbl,ylbl=ylbl,xtls=xtls)
    fig.subplots_adjust(top=0.925)
    # horizontal colorbar above the panels; a [left, bottom, width, height] rect requires add_axes
    cbax = fig.add_axes([0.15,0.93,0.73,0.01])
    cbar = fig.colorbar(img,cax=cbax,orientation='horizontal')
    cbar.ax.xaxis.set_ticks_position('top')
    cbar.ax.xaxis.set_label_position('top')
    cbar.set_label('Irreversibility',size=8)
    plt.setp(cbar.ax.xaxis.get_ticklabels(),size=6)
    torblbl = 'top_genes' if iii==0 else 'bottom_genes'
    fig.savefig('figs/irreversibility_by_case_%s.svg' % torblbl)


## Fig. 3 (Properties of the rules as a function of $r$ and $s$)

In [None]:
import sympy
from sympy.abc import r,s
from sympy.logic.boolalg import truth_table
from collections import defaultdict

def generate_possible_rules_probabilities_new(unas,expr='',prob=1,paren=0):
    """Enumerate every rule built from the inputs `unas` together with its probability.

    Inputs are consumed left to right. After each non-final input one of three
    connectors is appended:
      ' x & ( '  with probability r          (opens a nested clause),
      ' x | '    with probability (1-r)*s,
      ' x & '    with probability (1-r)*(1-s).
    The final input closes every open parenthesis. `r` and `s` are the
    module-level symbols.

    Returns a list of ``([], rule_string, probability, 0)`` tuples, one per leaf
    of the 3-way branching in depth-first order (3**(len(unas)-1) entries).
    """
    head = unas[0]
    if len(unas)==1:
        # termination: append the last input and close every open '('
        return [([], expr + ('%s ' % (head)) + ')'*paren, prob, 0)]
    rest = unas[1:]
    branches = (
        (expr + ' %s & ( ' % (head), prob*r, paren + 1),
        (expr + ' %s | ' % (head), prob*(1-r)*s, paren),
        (expr + ' %s & ' % (head), prob*(1-r)*(1-s), paren),
    )
    rules = []
    for branch_expr, branch_prob, branch_paren in branches:
        rules.extend(generate_possible_rules_probabilities_new(rest, branch_expr, branch_prob, branch_paren))
    return rules

def parse_rule_attrs(rstr, ch='a'):
    """Group the inputs of a rule string by nesting ("canalizing") depth.

    `rstr` is a printed Boolean rule whose inputs are named ``<ch>_<k>``
    (e.g. ``'a_0 & (a_1 | a_2)'``). The string is split on ``'<ch>_'`` so each
    fragment starts with an input index followed by the operator/parenthesis
    text that comes after it. A running depth counter `cd` (starting at 1) is
    increased by fragments that open parentheses, and input indices are
    recorded under the depth at which they are flushed.

    Parameters
    ----------
    rstr : str
        Rule string, e.g. ``str(sympy_expr)``.
    ch : str
        Input-name prefix (default ``'a'``).

    Returns
    -------
    dict
        depth (int) -> list of input-index strings, e.g.
        ``{1: ['0'], 2: ['1', '2']}`` for ``'a_0 & (a_1 | a_2)'``.
    """
    L = rstr.split('%s_' % ch)
    cd = 1  # current depth
    cd_d = defaultdict(list)
    inp_l = []  # indices buffered since the last flush into cd_d
    for elt in L:
        if elt=='': ## string starts with no parentheses
            continue
        elif elt.startswith('('): ## string starts with parentheses (at most 1)
            cd+=1
        elif '& (' in elt:
            # AND followed by a nested group: flush, then go deeper by the number of '('
            inp_l.append(elt.split(' &')[0])
            cd_d[cd].extend(inp_l)
            inp_l=[]
            cd += elt.count('(')
        elif ') | (' in elt:
            # group closed and a new OR-ed group opened: flush; depth unchanged
            inp_l.append(elt.split(')')[0])
            cd_d[cd].extend(inp_l)
            inp_l=[]
        elif '| (' in elt:
            # OR followed by a parenthesised group: flush, then depth increases by 2
            inp_l.append(elt.split(' |')[0])
            cd_d[cd].extend(inp_l)
            inp_l=[]
            cd +=2
        elif ' | ' in elt: ## this case MUST COME AFTER ') | ('
            ## inp_l will always be empty in this case
            inp_l.append(elt.split(' | ')[0])
            cd_d[cd].extend(inp_l)
            inp_l=[]
        elif ' & ' in elt:
            # plain AND: only buffered; recorded at the next flush
            inp_l.append(elt.split(' & ')[0])
        elif elt.endswith(')'): ## last term
            inp_l.append(elt.split(')')[0])
            cd_d[cd].extend(inp_l)
            inp_l=[]
        else:
            # bare index (no operator after it): flush
            inp_l.append(elt)
            cd_d[cd].extend(inp_l)
            inp_l=[]
    return dict(cd_d)

In [None]:
## generate the possible rules for networks up to size 7
## For each number of inputs tsz, enumerate all rules, merge rules that sympy prints
## identically, and accumulate two polynomials in (r, s):
##   bias_poly_l: expected number of True rows of the truth table
##   myexpr_l:    expected average canalizing depth

myexpr_l = []
bias_poly_l = []
for tsz in range(2,8):
    print(tsz)
    # input symbols a_0 ... a_{tsz-1}; the 'a' prefix matches parse_rule_attrs' default ch='a'
    avec = sympy.symarray('a',tsz)
    __,RULES,PROBS,__ = zip(*generate_possible_rules_probabilities_new(avec.tolist(),expr='',prob=1,paren=0))
    sympRULES = [sympy.sympify(RR) for RR in RULES]
    rule_cats = pd.Categorical([str(RR) for RR in sympRULES])
    unique_rules = rule_cats.value_counts().index
    rule_prob_l = []
    for RR in unique_rules:
        cd_d = parse_rule_attrs(RR)
        TF = rule_cats==RR
        # total probability of all generated rules that print to this expression
        cPROB = 0
        for hit, pp in zip(TF, PROBS):
            if hit:
                cPROB+=pp

        rule_prob_l.append((RR,sympy.simplify(cPROB),cd_d))
    myexpr=0
    bias_l = []
    for expr,prob,cd_d in rule_prob_l:
        TT_df = pd.DataFrame([np.r_[bvars,op==sympy.true] for bvars,op in truth_table(expr,avec)])
        # number of input combinations for which the rule is True
        bias = TT_df.iloc[:,-1].sum()
        bias_l.append(bias*prob)
        ## code to calculate the average canalizing depth
        wt_sum = 0
        totL = 0
        for kk,vv in cd_d.items():
            wt_sum += kk*len(vv)
            totL+=len(vv)
        myexpr+=prob*sympy.Rational(wt_sum,totL) ## this now gives the expected canalizing depth
    bias_poly_l.append(sum(bias_l))
    myexpr_l.append(myexpr)
    


#### GENERATE CODE:

In [None]:
from sympy import pycode
## Print Python source for the polynomials computed above (to be pasted into the next cell).
for ii,expr in enumerate(myexpr_l):
    src = "def ex_cdepth%d(r,s):\n    return %s"% (ii+2,sympy.simplify(expr))
    print(src)
    print("")
# three blank lines separate the two groups of functions
print("")
print("")
print("")
for ii,bias in enumerate(bias_poly_l):
    src = "def ex_bias%d(r,s):\n    return %s"% (ii+2,sympy.simplify(bias))
    print(src)
    print("")

#### Paste code into cell below to define the polynomial functions for plotting:

In [None]:
def ex_cdepth2(r,s):
    """Expected average canalizing depth for N=2 inputs (generated; regenerate rather than edit).

    Constant: returns the scalar 1 even for array inputs.
    """
    return 1

def ex_cdepth3(r,s):
    """Expected average canalizing depth for N=3 inputs as a polynomial in r and s.

    Generated sympy output; term order fixes the floating-point result, so regenerate rather than edit.
    """
    return -8*r**2*s**2/3 + 2*r**2*s/3 + 16*r*s**2/3 - 10*r*s/3 - 8*s**2/3 + 8*s/3 + 1

def ex_cdepth4(r,s):
    """Expected average canalizing depth for N=4 inputs as a polynomial in r and s.

    Generated sympy output; term order fixes the floating-point result, so regenerate rather than edit.
    """
    return -r**3*s**3 + 3*r**3*s**2/4 + r**3*s/4 + 3*r**2*s**3 - 13*r**2*s**2/2 - \
           r**2*s/4 - 3*r*s**3 + 43*r*s**2/4 - 4*r*s + s**3 - 5*s**2 + 4*s + 1

def ex_cdepth5(r,s):
    """Expected average canalizing depth for N=5 inputs as a polynomial in r and s.

    Generated sympy output; term order fixes the floating-point result, so regenerate rather than edit.
    """
    return 2*r**4*s**3/5 - 2*r**4*s/5 - 16*r**3*s**3/5 + 8*r**3*s**2/5 + 8*r**3*s/5 + \
           36*r**2*s**3/5 - 52*r**2*s**2/5 - 11*r**2*s/5 - 32*r*s**3/5 + 16*r*s**2 - \
           21*r*s/5 + 2*s**3 - 36*s**2/5 + 26*s/5 + 1

def ex_cdepth6(r,s):
    """Expected average canalizing depth for N=6 inputs as a polynomial in r and s.

    Generated sympy output; term order fixes the floating-point result, so regenerate rather than edit.
    """
    return r**5*s**4/6 + r**5*s**3/6 - 5*r**5*s**2/6 + r**5*s/2 - 5*r**4*s**4/3 + \
           13*r**4*s**3/6 + 13*r**4*s**2/6 - 8*r**4*s/3 + 5*r**3*s**4 - 25*r**3*s**3/2 + \
           2*r**3*s**2 + 11*r**3*s/2 - 20*r**2*s**4/3 + 137*r**2*s**3/6 - 35*r**2*s**2/2 - \
           17*r**2*s/3 + 25*r*s**4/6 - 53*r*s**3/3 + 49*r*s**2/2 - 4*r*s - s**4 + 5*s**3 - \
           31*s**2/3 + 19*s/3 + 1

def ex_cdepth7(r,s):
    """Expected average canalizing depth for N=7 inputs as a polynomial in r and s.

    Generated sympy output; term order fixes the floating-point result, so regenerate rather than edit.
    """
    return -r**6*s**6 + 19*r**6*s**5/7 - 17*r**6*s**4/7 + r**6*s**3/7 + 8*r**6*s**2/7 - \
           4*r**6*s/7 + 6*r**5*s**6 - 123*r**5*s**5/7 + 130*r**5*s**4/7 - 33*r**5*s**3/7 - \
           6*r**5*s**2 + 26*r**5*s/7 - 15*r**4*s**6 + 330*r**4*s**5/7 - 58*r**4*s**4 + \
           187*r**4*s**3/7 + 64*r**4*s**2/7 - 10*r**4*s + 20*r**3*s**6 - 470*r**3*s**5/7 + \
           664*r**3*s**4/7 - 457*r**3*s**3/7 + 25*r**3*s**2/7 + 14*r**3*s - 15*r**2*s**6 + \
           375*r**2*s**5/7 - 601*r**2*s**4/7 + 564*r**2*s**3/7 - 215*r**2*s**2/7 - 78*r**2*s/7 + \
           6*r*s**6 - 159*r*s**5/7 + 286*r*s**4/7 - 346*r*s**3/7 + 261*r*s**2/7 - 24*r*s/7 - \
           s**6 + 4*s**5 - 8*s**4 + 12*s**3 - 101*s**2/7 + 52*s/7 + 1

def ex_bias2(r,s):
    """Expected number of True rows in the 2**N-row truth table for N=2 inputs (polynomial in r, s).

    Divide by 2**N for the fraction of True outputs. Generated sympy output; regenerate rather than edit.
    """
    return -2*r*s + 2*s + 1

def ex_bias3(r,s):
    """Expected number of True rows in the 2**N-row truth table for N=3 inputs (polynomial in r, s).

    Divide by 2**N for the fraction of True outputs. Generated sympy output; regenerate rather than edit.
    """
    return -2*r**2*s**2 + 2*r**2*s + 4*r*s**2 - 10*r*s - 2*s**2 + 8*s + 1

def ex_bias4(r,s):
    """Expected number of True rows in the 2**N-row truth table for N=4 inputs (polynomial in r, s).

    Divide by 2**N for the fraction of True outputs. Generated sympy output; regenerate rather than edit.
    """
    return 4*r**3*s**2 - 4*r**3*s - 16*r**2*s**2 + 16*r**2*s + 20*r*s**2 - 34*r*s - 8*s**2 + 22*s + 1

def ex_bias5(r,s):
    """Expected number of True rows in the 2**N-row truth table for N=5 inputs (polynomial in r, s).

    Divide by 2**N for the fraction of True outputs. Generated sympy output; regenerate rather than edit.
    """
    return 4*r**4*s**4 - 12*r**4*s**2 + 8*r**4*s - 16*r**3*s**4 + 8*r**3*s**3 + 48*r**3*s**2 -\
           40*r**3*s + 24*r**2*s**4 - 24*r**2*s**3 - 78*r**2*s**2 + 78*r**2*s - 16*r*s**4 + \
           24*r*s**3 + 60*r*s**2 - 98*r*s + 4*s**4 - 8*s**3 - 18*s**2 + 52*s + 1

def ex_bias6(r,s):
    """Expected number of True rows in the 2**N-row truth table for N=6 inputs (polynomial in r, s).

    Divide by 2**N for the fraction of True outputs. Generated sympy output; regenerate rather than edit.
    """
    return 8*r**5*s**5 - 16*r**5*s**4 - 8*r**5*s**3 + 32*r**5*s**2 - 16*r**5*s - 40*r**4*s**5 + \
           108*r**4*s**4 - 164*r**4*s**2 + 96*r**4*s + 80*r**3*s**5 - 272*r**3*s**4 + 112*r**3*s**3 + \
           316*r**3*s**2 - 236*r**3*s - 80*r**2*s**5 + 328*r**2*s**4 - 256*r**2*s**3 - 292*r**2*s**2 + \
           300*r**2*s + 40*r*s**5 - 192*r*s**4 + 216*r*s**3 + 132*r*s**2 - 258*r*s - 8*s**5 + 44*s**4 - \
           64*s**3 - 24*s**2 + 114*s + 1

def ex_bias7(r,s):
    """Expected number of True rows in the 2**N-row truth table for N=7 inputs (polynomial in r, s).

    Divide by 2**N for the fraction of True outputs. Generated sympy output; regenerate rather than edit.
    """
    return 8*r**6*s**6 - 40*r**6*s**5 + 40*r**6*s**4 + 40*r**6*s**3 - 80*r**6*s**2 + 32*r**6*s - \
           48*r**5*s**6 + 280*r**5*s**5 - 368*r**5*s**4 - 136*r**5*s**3 + 496*r**5*s**2 - 224*r**5*s + \
           120*r**4*s**6 - 800*r**4*s**5 + 1328*r**4*s**4 - 72*r**4*s**3 - 1240*r**4*s**2 + 664*r**4*s - \
           160*r**3*s**6 + 1200*r**3*s**5 - 2432*r**3*s**4 + 936*r**3*s**3 + 1528*r**3*s**2 - 1072*r**3*s + \
           120*r**2*s**6 - 1000*r**2*s**5 + 2408*r**2*s**4 - 1616*r**2*s**3 - 914*r**2*s**2 + 1002*r**2*s - \
           48*r*s**6 + 440*r*s**5 - 1232*r*s**4 + 1152*r*s**3 + 204*r*s**2 - 642*r*s + 8*s**6 - 80*s**5 + \
           256*s**4 - 304*s**3 + 6*s**2 + 240*s + 1

#### Contour plots as a function of $r$ and $s$

In [None]:
import matplotlib.pyplot as plt
# LaTeX text rendering for this figure (needs a TeX installation); later cells set usetex back to False
plt.rcParams['svg.fonttype']='none'
plt.rcParams['text.usetex']=True
plt.rcParams['font.family']='serif'
# 50-point grid over r, s in [0, 1]
rvals = np.linspace(0,1)
svals = np.linspace(0,1)
vvals1 = np.linspace(0, 1, 11, endpoint=True)   # contour levels for the normalised bias
vvals2 = np.linspace(0.14, 0.56, 8, endpoint=True)  # contour levels for the normalised depth
RR,SS = np.meshgrid(rvals,svals)
# after .T: row 0 = bias polynomials, row 1 = canalizing-depth polynomials; columns N = 3..7
func_array = np.asarray([#[ex_bias2,ex_cdepth2],
                         [ex_bias3,ex_cdepth3],
                         [ex_bias4,ex_cdepth4],
                         [ex_bias5,ex_cdepth5],
                         [ex_bias6,ex_cdepth6],
                         [ex_bias7,ex_cdepth7],
                        ]).T
fig,ax_arr = plt.subplots(func_array.shape[0],func_array.shape[1],sharex=True,sharey=True,figsize=(6.5,3))
for ii,frow in enumerate(func_array):
    for jj,ff in enumerate(frow):
        #vm,vM = (0,1) #if jj==0 else ()
        BB = ff(RR,SS)
        # a constant polynomial returns a scalar; it is replaced by an array of ones (assumes the constant is 1)
        if not isinstance(BB,np.ndarray):
            BB = np.ones(RR.shape)
        # bias is normalised by the 2**N truth-table rows, depth by N (N = jj+3)
        DENOM = 2**(jj+3) if ii==0 else (jj+3)
        vvals = vvals1 if ii==0 else vvals2
        cplot = ax_arr[ii,jj].contourf(RR,SS,BB/DENOM,vvals)
        if jj==0:
            plt.setp(ax_arr[ii,jj].get_yticklabels(),size=6)
        elif jj==func_array.shape[1]-1:
            # NOTE(review): the colorbar takes its space only from this last axis, so it is narrower than the others
            lbl = 'Bias' if ii==0 else 'Avg. canalization depth'
            cbar=fig.colorbar(cplot,ax=ax_arr[ii,jj])
            cbar.set_label(lbl,size=7)
            plt.setp(cbar.ax.yaxis.get_ticklabels(),size=6)
        if ii==func_array.shape[0]-1:
            plt.setp(ax_arr[ii,jj].get_xticklabels(),size=6)
        elif ii==0:
            ax_arr[ii,jj].set_title('N=%d' % (jj+3),size=8)

for ii in range(func_array.shape[0]):
    ax_arr[ii,0].set_ylabel(r'$s$: P($y_i+y_{i+1}$)',size=7)
for ii in range(func_array.shape[1]):
    ax_arr[-1,ii].set_xlabel(r'$r$: P($y_i\times(y_{i+1}$ )',size=7)
    # if ii==0:
    #     ax_arr[0,ii].set_title('Bias',size=8)
    # else:
    #     ax_arr[0,ii].set_title('Avg. canalization depth',size=8)
#cb = plt.colorbar()
#cb.set_label('Avg. canalization depth')
fig.savefig('figs/network_contours.svg')

## Fig. S1: 
Convergence of irreversibility estimated from independent sets of realizations.

_These cells require the script_ `analyze_replicates.py` _to be run_.

In [None]:
## Replicate-convergence statistics written by analyze_replicates.py.
## NOTE: unpickling can execute code -- only load files produced by this project.
ovr_out_d = pd.read_pickle('tmp/overall_output_d.pkl')
# one row per key (s, r, srt, N, quant, stat), grouped by parameter set and input sorting
key_columns = ['s','r','srt','N','quant','stat']
keyDF = pd.DataFrame(list(ovr_out_d.keys()), columns=key_columns)
statGB = keyDF.groupby(['s','r','srt'])
def ss2bxp(ss):
    """Convert a describe()-style summary Series into a matplotlib ``Axes.bxp`` stats dict.

    `ss` must be indexed by '25%', '50%', '75%', 'min', 'max' and 'mean'. The
    whiskers are placed at the minimum and maximum.
    """
    # (bxp key, summary label) pairs: median, quartiles, whiskers, mean
    key_map = (('med', '50%'), ('q1', '25%'), ('q3', '75%'),
               ('whislo', 'min'), ('whishi', 'max'), ('mean', 'mean'))
    return {bxp_key: ss.loc[stat_key] for bxp_key, stat_key in key_map}


### Fig. S1A:
Ascending results

In [None]:
ncol = 4  # panels per row
# plain-text (non-LaTeX) rendering again after the contour figure
plt.rcParams['svg.fonttype']='none'
plt.rcParams['text.usetex']=False
plt.rcParams['font.family']='serif'

fig,ax_arr = plt.subplots(3,ncol,figsize=(7.5,5),dpi=200,sharey=True,sharex=True)
sel_srt = 'asc'
fig_suptitle = 'Input sorting: Ascending'
for (ss,rr,srt),grp in statGB:
    if srt!=sel_srt:
        continue
    # panel placement:
    #   row 2: ss == '1.00', column int(r/0.2)-1
    #   row 1: s + r == 1 (exact float comparison), column int(r/0.2)-1
    #   row 0: everything else, column ncol - int(s/0.2)
    # NOTE(review): r = 0 gives column -1 (wraps to the last panel) and r = 1 / s = 0 give
    # column 4 (IndexError with ncol = 4) -- confirm those combinations are absent
    if ss=='1.00':
        ax_row = 2
        ax_col = int((float(rr)+1e-12)/0.20)-1
    elif float(ss) + float(rr)==1:
        ax_row = 1
        ax_col = int((float(rr)+1e-12)/0.20)-1
    else:
        ax_row = 0
        ax_col = ncol-int((float(ss)+1e-12)/0.20)
        ax_l = [ax_arr[ax_row,ax_col]]  # only referenced by the commented-out loop below
    # only the irreversibility RMSD statistic is plotted
    for quant,grp2 in grp.groupby('quant'):
        if quant !='irrev':
            continue
        for stat,grp3 in grp2.groupby('stat'):
            if stat != 'rmsd':
                continue
            else:
                ylbl = 'Irreversibility RMSD'
            # bxp stats dict per set size N (each row of grp3 is a full key of ovr_out_d)
            dod = {}
            for nn,row in grp3.iterrows():
                kk = tuple(row.values)
                dod[row.N] = ss2bxp(ovr_out_d[kk])
            # one box per set size N = 1..10 (a missing N raises KeyError)
            bxp_l = [dod[ii] for ii in range(1,11)]
            pos_l = range(1,11)
            #for axnum,ax in enumerate(ax_l):
            ax = ax_arr[ax_row,ax_col]
            ax.bxp(bxp_l,pos_l,0.7,showmeans=False,showfliers=False)
            if (ax_row==2):
                ax.set_xlabel('Realizations in set',size=8)
                plt.setp(ax.get_xticklabels(),size=6)
            if ax_col==0:
                ax.set_ylabel(ylbl,size=8)
                plt.setp(ax.get_yticklabels(),size=6)
            ax.set_title(f's:{ss} r:{rr}',size=6)

fig.suptitle(fig_suptitle)
fig.savefig('figs/rmsd_convergence_ascending.svg')

### Fig. S1B
Descending results

In [None]:
# Same layout as the ascending figure (reuses `ncol` from that cell); only the sorting,
# title and output file differ.
fig,ax_arr = plt.subplots(3,ncol,figsize=(7.5,5),dpi=200,sharey=True,sharex=True)
sel_srt = 'desc'
fig_suptitle = 'Input sorting: Descending'

for (ss,rr,srt),grp in statGB:
    if srt!=sel_srt:
        continue
    # panel placement:
    #   row 2: ss == '1.00', column int(r/0.2)-1
    #   row 1: s + r == 1 (exact float comparison), column int(r/0.2)-1
    #   row 0: everything else, column ncol - int(s/0.2)
    # NOTE(review): r = 0 gives column -1 (wraps to the last panel) and r = 1 / s = 0 give
    # column 4 (IndexError with ncol = 4) -- confirm those combinations are absent
    if ss=='1.00':
        ax_row = 2
        ax_col = int((float(rr)+1e-12)/0.20)-1
    elif float(ss) + float(rr)==1:
        ax_row = 1
        ax_col = int((float(rr)+1e-12)/0.20)-1
    else:
        ax_row = 0
        ax_col = ncol-int((float(ss)+1e-12)/0.20)
        ax_l = [ax_arr[ax_row,ax_col]]  # only referenced by the commented-out loop below
    # only the irreversibility RMSD statistic is plotted
    for quant,grp2 in grp.groupby('quant'):
        if quant !='irrev':
            continue
        for stat,grp3 in grp2.groupby('stat'):
            if stat != 'rmsd':
                continue
            else:
                ylbl = 'Irreversibility RMSD'
            # bxp stats dict per set size N (each row of grp3 is a full key of ovr_out_d)
            dod = {}
            for nn,row in grp3.iterrows():
                kk = tuple(row.values)
                dod[row.N] = ss2bxp(ovr_out_d[kk])
            # one box per set size N = 1..10 (a missing N raises KeyError)
            bxp_l = [dod[ii] for ii in range(1,11)]
            pos_l = range(1,11)
            #for axnum,ax in enumerate(ax_l):
            ax = ax_arr[ax_row,ax_col]
            ax.bxp(bxp_l,pos_l,0.7,showmeans=False,showfliers=False)
            if (ax_row==2):
                ax.set_xlabel('Realizations in set',size=8)
                plt.setp(ax.get_xticklabels(),size=6)
            if ax_col==0:
                ax.set_ylabel(ylbl,size=8)
                plt.setp(ax.get_yticklabels(),size=6)
            ax.set_title(f's:{ss} r:{rr}',size=6)

fig.suptitle(fig_suptitle)
fig.savefig('figs/rmsd_convergence_descending.svg')