In [1]:
import numpy as np
import pandas as pd
import networkx as nx
from glob import glob
from networkx.readwrite.graphml import read_graphml
from networkx.algorithms.shortest_paths.generic import shortest_path_length
import matplotlib.pyplot as plt
import patsy
import numpy as np
from scipy.stats import scoreatpercentile as SAP
from collections import defaultdict
from sklearn.metrics import roc_curve,RocCurveDisplay

# Keep text as <text> elements (not glyph outlines) in exported SVGs so labels stay editable.
plt.rcParams['svg.fonttype']='none'

## Preprocessing of the RNA-Seq data

In [2]:
data = pd.read_table("input_files/GSE152214_counts_new_GSM5227097-GSM5227106.txt.gz",index_col=1,sep=" ")
data2 = pd.read_table("input_files/GSE152214_Raw_counts.txt.gz",index_col=0,sep="\t")


def _to_tpm(table, n_annot):
    """Return the TPM matrix (genes x samples) of a raw count table.

    The first ``n_annot`` columns of ``table`` are annotation columns and must
    include ``Length`` (divided by 1000 here to obtain kilobases); every
    remaining column holds the read counts of one sample.
    """
    lengths_kb = table.iloc[:, :n_annot].Length / 1000
    reads = table.iloc[:, n_annot:]
    per_kb = (reads.T / lengths_kb).T              # reads per kilobase (RPK)
    return per_kb / per_kb.sum(axis=0) * 1e6       # rescale each sample to 1e6 in total


# The GSM5227097-106 table carries 5 annotation columns, the raw-count table 4.
TPM_dat = pd.concat([_to_tpm(data, 5), _to_tpm(data2, 4)], axis=1)


def _describe_sample(name):
    """Return (strain, cultivation) parsed from a sample column name, or None.

    ``<strain>_<suffix>``            -> (strain, 'batch')
    ``<strain>_chemostat_<suffix>``  -> (strain, 'chemostat')
    ``<a>_<b>_<suffix>`` (otherwise) -> ('a+b', 'batch')
    Names with any other number of '_'-separated fields are not described.
    """
    fields = name.split('_')
    if len(fields) == 2:
        return fields[0].lower(), 'batch'
    if len(fields) != 3:
        return None
    if 'chemostat' in name:
        return fields[0].lower(), 'chemostat'
    return '+'.join(fields[:2]).lower(), 'batch'


from collections import defaultdict
my_d = defaultdict(dict)
for sample in TPM_dat.columns:
    labels = _describe_sample(sample)
    if labels is None:
        continue
    strain, cultivation = labels
    my_d[sample]['strain'] = strain
    my_d[sample]['cultivation'] = cultivation


# log10 TPM with a 1e-3 pseudocount, shifted by +3 so that zero-TPM genes sit at 0.
logTPM_dat = np.log10(TPM_dat + 1e-3) + 3

### Fit strain and cultivation effects (linear model for batch correction)

In [3]:
# One row per parsed sample with its 'strain' and 'cultivation' labels.
exp_meta = pd.DataFrame(dict(my_d)).T
# Number of genes (rows of logTPM_dat). NOTE(review): not used in the visible cells.
n_array = logTPM_dat.shape[0]


# Design matrix: one indicator column per strain (no intercept) plus a single
# chemostat indicator (batch cultivation is the reference level).
mod = patsy.dmatrix("~0+strain",exp_meta,return_type='dataframe')
mod2 = patsy.dmatrix("~0 + cultivation",exp_meta,return_type='dataframe')
design = pd.concat([mod,mod2.loc[:,['cultivation[chemostat]']]],axis=1)
# Ordinary least squares via SVD. numpy returns X = U diag(S) V with V already
# transposed (V^H), so (X^T X)^{-1} = V^T diag(1/S^2) V.
U,S,V = np.linalg.svd(design,full_matrices=False)
des_inv = np.dot(V.T*1/S**2,V)
des_inv_df = pd.DataFrame(des_inv,design.columns,design.columns)
# Coefficients (design columns x genes): (X^T X)^{-1} X^T Y with Y = logTPM_dat.T.
B_hat = des_inv_df@design.T@logTPM_dat.T
# Per-gene mean squared residual (divides by the number of samples, not the residual dof).
# NOTE(review): the fit is computed, but no covariate-corrected expression matrix is
# produced here, and var_pooled is not used in the visible cells.
var_pooled = (np.square(logTPM_dat-(design@B_hat).T) ).mean(axis=1)


### Gather the irreversible response gene data

In [None]:
import subprocess
import os
from itertools import combinations


def _count_newlines(path, chunk_size=1 << 20):
    """Return the number of newline bytes in ``path`` (the count ``wc -l`` prints).

    Pure-Python replacement for shelling out to ``wc``: no subprocess per file
    and it also works where ``wc`` is unavailable (e.g. Windows).
    """
    with open(path, 'rb') as fh:
        return sum(chunk.count(b'\n') for chunk in iter(lambda: fh.read(chunk_size), b''))


# For every irreversible-response result file, keyed by (pt, ss, rr, srt, rep)
# parsed from its file name, collect:
#   irr_resp_gn_d -- per gene: fraction of the listed attractors containing it
#   crp_freq_d    -- listed attractors / all attractors in the matching attfiles file
#   crp_joint_d   -- per gene pair: (shared attractors / tot_attr) divided by both
#                    per-gene fractions
crp_res_l = glob('results/result-irrgn-OE-twoparam*.csv')
irr_resp_gn_d = {}
crp_freq_d = {}
crp_joint_d = {}
for fn in crp_res_l:
    tail = os.path.basename(fn)   # portable, unlike fn.split('/')
    itail,ss,rr,srt,rep,__ = tail.split('_')
    pt = itail.split('-')[2]
    # NOTE(review): rep is substituted twice (as %s and as %d) -- confirm this
    # matches the attfiles naming scheme.
    att_fn = 'attfiles/1st_twoparam_%s_%s_%s_%s_%d.csv' % (ss,rr,srt,rep,int(rep))
    freq_dd = defaultdict(int)
    gene_dd = defaultdict(list)
    att_l = []
    with open(fn) as fh:
        for ln in fh:
            if not ln.strip():
                continue   # tolerate blank / trailing lines instead of crashing on unpacking
            # Each line: <attractor id>\t<gene>, <gene>, ...
            att_num,gn_l = ln.strip().split('\t')
            att_l.append(att_num)
            for gn in gn_l.strip().split(', '):
                freq_dd[gn]+=1
                gene_dd[gn].append(att_num)
    # Total number of attractors (one per line) for this parameter set.
    tot_attr = _count_newlines(att_fn)
    crp_freq = len(att_l)/tot_attr
    freq_ser = pd.Series(dict(freq_dd))/len(att_l)
    key = (pt,ss,rr,srt,rep)
    irr_resp_gn_d[key] = freq_ser
    crp_freq_d[key] = crp_freq
    gene_d = dict(gene_dd)
    # Build each gene's attractor set once instead of once per pair.
    att_sets = {gn: set(atts) for gn, atts in gene_d.items()}
    crp_jfd = {}
    # combinations() over the sorted genes yields each pair once with gn1 < gn2.
    for gn1, gn2 in combinations(sorted(att_sets), 2):
        crp_jfd[(gn1,gn2)] = len(att_sets[gn1]&att_sets[gn2])/tot_attr /freq_ser.loc[gn2] /freq_ser.loc[gn1]
    crp_jf_ser = pd.Series(crp_jfd)
    crp_joint_d[key]=crp_jf_ser
irr_resp_gn_df = pd.DataFrame(irr_resp_gn_d)
crp_freq_ser = pd.Series(crp_freq_d)

In [None]:
# Weight each run's per-gene frequency by that run's attractor fraction
# (columns align on the (pt, ss, rr, srt, rep) keys), then average over all
# runs sharing the same key level 3 (the `srt` field of the result file names;
# presumably 'asc'/'desc' -- see the fillna in the statistics cell).
irg_cmb_freq = (irr_resp_gn_df*crp_freq_ser)
# Group the columns through a transpose: DataFrame.groupby(axis=1) is
# deprecated since pandas 2.1 and the transpose form gives the same result.
freq_irrev_df = irg_cmb_freq.fillna(0).T.groupby(level=[3]).mean().T
freq_irrev_df.to_pickle('ensavg_irr_resp_gn_df.pkl')


In [4]:
def load_graph_stats(src='n23', graph_path='results/rs2_irr_neg2.graphml'):
    """Obtain per-gene graph statistics relative to the source (crp) node.

    Parameters
    ----------
    src : str
        GraphML node id used as the source. The default 'n23' is presumably
        the crp node -- verify against the node labels.
    graph_path : str
        GraphML file to read (default keeps the original hard-coded file).

    Returns
    -------
    DF : pandas.DataFrame indexed by node label, with columns
        shrt_path -- shortest-path length (in edges) from ``src``; inf if unreachable
        in_degree -- in-degree of the node
        crp_signs -- 'weight' attribute of the edge src -> node; 0 if there is none
        auto_reg  -- 'weight' attribute of the node's self-loop; 0 if there is none
    G : the graph read from ``graph_path``
    nd_lbl_d : dict mapping node id -> label
    lbl_nd_d : dict mapping label -> node id (last node wins on duplicate labels)
    """
    G = read_graphml(graph_path)
    nd_lbl_d = {nd: G.nodes[nd]['label'] for nd in G.nodes()}
    lbl_nd_d = {lbl: nd for nd, lbl in nd_lbl_d.items()}
    shrt_paths = pd.Series({nd_lbl_d[nd]: dist for nd, dist in shortest_path_length(G, src).items()})
    in_degrees = pd.Series({nd_lbl_d[nd]: deg for nd, deg in G.in_degree()})
    crp_signs = pd.Series({nd_lbl_d[nd]: G.edges[src, nd]['weight'] for nd in G.successors(src)})
    auto_reg_d = {}
    for nd in G.nodes():
        try:
            auto_reg_d[nd_lbl_d[nd]] = G.edges[nd, nd]['weight']
        except KeyError:
            # No self-loop (or a self-loop without a weight attribute).
            auto_reg_d[nd_lbl_d[nd]] = 0
    auto_reg = pd.Series(auto_reg_d)
    # Name the columns up front so the fill values are keyed by name rather than
    # by the positional integers pd.concat assigns to unnamed Series.
    DF = pd.concat({'shrt_path': shrt_paths, 'in_degree': in_degrees,
                    'crp_signs': crp_signs, 'auto_reg': auto_reg}, axis=1)
    # Unreachable nodes are infinitely far; nodes without a direct crp edge get sign 0.
    DF = DF.fillna({'shrt_path': np.inf, 'crp_signs': 0})
    return DF, G, nd_lbl_d, lbl_nd_d

graph_stats,G,nd_lbl_d,lbl_nd_d = load_graph_stats()


## Statistical analysis of the RNA-Seq data

In [5]:
# Gene-level table: ensemble-averaged irreversibility probabilities joined with
# the crp-relative graph statistics from load_graph_stats.
DF = pd.concat([freq_irrev_df, graph_stats.shrt_path, graph_stats.in_degree, graph_stats.crp_signs, graph_stats.auto_reg],axis=1)
# The graph_stats Series keep their names ('shrt_path', 'in_degree', ...), so the
# fill values must be keyed by name -- integer keys 0..3 silently match nothing.
# Genes missing from the graph: unreachable (inf), no inputs, no self-loop;
# crp_signs is deliberately left as NaN.
DF = DF.fillna({'asc':0, 'desc':0, 'shrt_path':np.inf, 'in_degree':0, 'auto_reg':0})
DF.columns = [('','','irr_prob_asc'),('','','irr_prob_desc'),('','','shrt_path'),('','','in_degree'),('','','crp_signs'),('','','auto_reg')]
DF.columns = pd.MultiIndex.from_tuples(DF.columns)

# Replicate columns of TPM_dat for each (strain, cultivation) state.
state_cols_d = {('wt','batch'):['WT_S1','WT_S2'],
                ('crp','batch'):['CRP_S1','CRP_S2'],
                ('evo1','batch'):['Evocrp1_S1','Evocrp1_S2'],
                ('evo2','batch'):['Evocrp2_S1','Evocrp2_S2'],
                ('evo3','batch'):['Evocrp3_S1','Evocrp3_S2'],
                ('evo4','batch'):['Evocrp4_S1','Evocrp4_S2'],
                ('evo5','batch'):['Evocrp5_S1','Evocrp5_S2'],
                ('mut','batch'):['IG116_crp_S1','IG116_crp_S2'],
                ('wt','chemo'):['WT_chemostat_S1','WT_chemostat_S2'],
                ('crp','chemo'):['CRP_chemostat_S1','CRP_chemostat_S2'],
                ('evo1','chemo'):['EvoCrp1_chemostat_S1','EvoCrp1_chemostat_S2'],
                ('evo3','chemo'):['EvoCrp3_chemostat_S1','EvoCrp3_chemostat_S2'],
               }

# (final state, reference state, cultivation) triples whose fold changes are computed.
statepair_cols_l = [
                ('crp','wt','batch'), ('crp','wt','chemo'),
                ('evo1','wt','chemo'), ('evo3','wt','chemo'),
                ('evo1','wt','batch'), ('evo2','wt','batch'),
                ('evo3','wt','batch'), ('evo4','wt','batch'),
                ('evo5','wt','batch'), ('mut','wt','batch')]

states_d ={}
avgexp_d={}
stdexp_d={}
cvexp_d = {}
from scipy.stats import scoreatpercentile as SAP

for (st,lbl),cols in state_cols_d.items():
    reps = TPM_dat.loc[:,cols]
    states_d[(st,lbl)] = reps
    MUS = reps.mean(axis=1)
    avgexp_d[(st,'mean',lbl)] = MUS
    SIGS = reps.std(axis=1)
    CV = SIGS/MUS
    # Regularize the two-replicate SD: clip it to [p5, p95] of the CV observed
    # among genes with mean TPM > 1, scaled by each gene's mean.
    lbcv,ubcv = SAP((CV[MUS>1]),[5,95])
    LB = MUS*lbcv
    UB = MUS*ubcv
    stdexp_d[(st,'sig',lbl)] = SIGS.clip(LB,UB)
    cvexp_d[(st,'cv',lbl)] = SIGS.clip(LB,UB)/MUS

avgexp_df = pd.DataFrame(avgexp_d)
stdexp_df = pd.DataFrame(stdexp_d)
cvexp_df = pd.DataFrame(cvexp_d)

zv_d = {}
for st_f,st_i,ccond in statepair_cols_l:
    fin = states_d[(st_f,ccond)]
    init = states_d[(st_i,ccond)]
    # Fold change of the mean TPM (final / reference).
    fc = fin.mean(axis=1)/init.mean(axis=1)
    # Relative uncertainty of a ratio: the two (clipped) CVs added in quadrature.
    fc_sig = np.sqrt((cvexp_df.loc[:,(st_i,'cv',ccond)])**2 + (cvexp_df.loc[:,(st_f,'cv',ccond)])**2)
    zv_d[(st_f,st_i,ccond+'_fc')] = fc
    zv_d[(st_f,st_i,ccond+'_fcsig')] = fc_sig
zv_df = pd.DataFrame(zv_d)

combined_table = pd.concat([DF,zv_df,avgexp_df,stdexp_df,cvexp_df],axis=1)
combined_table=combined_table.loc[DF.index]

# Gene masks: within one edge of crp; in-degree below 3; wt mean TPM > 1 in
# batch / chemostat (the first matching column is the ('wt','mean',...) one).
C1 = (combined_table.xs('shrt_path',level=2,axis=1)<2).iloc[:,0] 
C2 = (combined_table.xs('in_degree',level=2,axis=1)<3).iloc[:,0]
CA = (combined_table.xs(('wt','batch'),level=[0,2],axis=1)>1).iloc[:,0]
CB = (combined_table.xs(('wt','chemo'),level=[0,2],axis=1)>1).iloc[:,0]

In [None]:
## TODO (open question): should genes unreachable from crp, or genes with zero
## irreversibility, be excluded from this comparison?


def _sign_consistency(table, cultivation_is_batch):
    """Return (condition label, per-gene table) for the precision-recall analysis.

    For every gene except crp, the mean fold change (vs. wt) over the evolved /
    mutant strains is divided by its gene-wide mean and log-transformed.
    ``expr_chng`` is the absolute value of that log ratio. ``data_match`` is
    True when the direction of change is opposite to the crp edge sign:
    a decrease with '+'/'pm', or an increase with '-'/'pm'.
    """
    label = 'batch' if cultivation_is_batch else 'chemo'
    if cultivation_is_batch:
        strains = ['evo%d' % k for k in range(1, 6)] + ['mut']
    else:
        strains = ['evo1', 'evo3']
    genes = table.loc[[name for name in table.index if name != 'crp']]
    mean_fc = genes.loc[:, [(s, 'wt', '%s_fc' % label) for s in strains]].mean(axis=1)
    log_fc = np.log(mean_fc / np.mean(mean_fc))
    edge_sign = genes.loc[:, ('', '', 'crp_signs')]
    ## note that the signs must be opposite because the perturbation is a knockdown
    down_with_pos_edge = (log_fc < 0) & edge_sign.isin(['+', 'pm'])
    up_with_neg_edge = (log_fc > 0) & edge_sign.isin(['-', 'pm'])
    result = pd.concat([np.abs(log_fc), down_with_pos_edge | up_with_neg_edge], axis=1)
    result.columns = ['expr_chng', 'data_match']
    return label, result


ROC_data_d2 = dict(_sign_consistency(combined_table, is_batch) for is_batch in (True, False))


In [None]:
# NOTE(review): `copy` is not used in this cell.
from copy import copy
import matplotlib.pyplot as plt
from sklearn.metrics import PrecisionRecallDisplay
plt.rcParams['svg.fonttype']='none'
plt.rcParams['text.usetex']=False
plt.rcParams['font.family']='sans'


# Precision-recall curves of |logFC| (score, expr_chng) against the crp-sign
# consistency label (data_match), one curve per cultivation condition.
fig,ax = plt.subplots(1,1,figsize=(3.375,3),dpi=180)
# NOTE(review): the AP values in these legend labels are hard-coded; recompute
# them if the inputs change.
cond_labels=['Batch (AP = 0.99)','Chemostat (AP = 0.85)']
for jj,cond in enumerate(['batch','chemo']):
    dat=ROC_data_d2[cond]
    # Genes with expr_chng > 0.5 (same cut-off as `thr` in the significance cells).
    sel_dat = dat[dat.expr_chng>0.5]
    INDS = sel_dat.expr_chng.sort_values().index[::-1]
    OP = PrecisionRecallDisplay.from_predictions(dat.data_match,dat.expr_chng,pos_label=True,ax=ax,color=f'C{jj}')
    # Dotted line: recall when only the len(INDS) top-scoring genes are called positive.
    # NOTE(review): counting from the end of OP.recall assumes untied scores and
    # that no intermediate thresholds were dropped -- verify.
    ax.axvline(OP.recall[-1*(len(INDS)+1)],ls=':',color=f'C{jj}')
# Hide the auto-generated legend, then rebuild it with the custom labels.
ax.legend().set_visible(False)
hndls,lbls = ax.get_legend_handles_labels()
ax.legend(hndls,cond_labels,frameon=False,prop={'size':6})
ax.set_ylabel('Precision')
ax.set_xlabel('Recall')
plt.setp(ax.get_xticklabels(),size=8)
plt.setp(ax.get_yticklabels(),size=8)
fig.savefig('figs/crp_precision_recall_curve.svg')

#### Calculate the shortest paths from crp to each gene and the sign of each path

In [None]:
# Node id of the node labelled 'crp' (raises IndexError if absent).
crp_nd = [nd for nd in G.nodes() if nd_lbl_d[nd]=='crp'][0]
nd_data_d = defaultdict(dict)
nd_pathl_d = {}   # node id -> list of all shortest paths from crp (node lists)
for nd in G.nodes():
    if nd==crp_nd:
        continue
    nd_data_d[nd]['name']=nd_lbl_d[nd]
    try:
        asp_l = list(nx.all_shortest_paths(G,crp_nd,nd))
        # NOTE(review): len(path) counts nodes (crp and nd included), i.e. edge
        # distance + 1 -- one more than shrt_path from load_graph_stats. Confirm intended.
        nd_data_d[nd]['crp_distance']=len(asp_l[0])
        nd_pathl_d[nd]=asp_l
    except nx.NetworkXNoPath:
        asp_l = []
        nd_data_d[nd]['crp_distance']=np.nan
    # Sign of each shortest path = product of its edge weights.
    # NOTE(review): this requires numeric weights, while the ROC cell compares
    # crp_signs (the same 'weight' attribute) with '+'/'-'/'pm' strings -- confirm
    # the GraphML weight type.
    wt_l = []
    for sp in asp_l:
        wt = 1
        for (u,v) in zip(sp[:-1],sp[1:]):
            wt*=G.edges[u,v]['weight']
        wt_l.append(wt)
    # NOTE(review): 'both' is assigned whenever there is more than one shortest
    # path, even if all of their signs agree.
    if len(wt_l)>1:
        nd_data_d[nd]['sign']='both'
    elif len(wt_l)>0:
        nd_data_d[nd]['sign']='positive' if wt_l[0]>0 else 'negative'
    else:
        nd_data_d[nd]['sign']='none'
    
crp_data_df = pd.DataFrame(dict(nd_data_d)).T
crp_data_df

#### Significance of the relationship between the adaptive evolution responses and the irreversible genes

In [None]:
# Observed statistic per condition: among strongly changed genes
# (expr_chng > thr), the number of sign-consistent genes (data_match) that are
# irreversible response genes (IRGs, non-zero irreversibility probability) minus
# the number that are not.  Null: IRG labels shuffled across genes, independently
# for batch and chemostat.
avg_irg_prob = combined_table.loc[:,[('','','irr_prob_asc'),('','','irr_prob_desc')]].sum(axis=1)>0
aaa = avg_irg_prob[avg_irg_prob.index!='crp']
full_table = pd.concat([pd.concat(ROC_data_d2,axis=1),aaa],axis=1)
cols = full_table.columns.tolist()
cols[-1]=('','IsIrrev')
full_table.columns = pd.MultiIndex.from_tuples(cols)
thr = 0.5          # expr_chng cut-off for a "strongly changed" gene
N_samp = 25000     # number of permutations
PERM_SEED = 12345  # fixed seed so the reported p-value is reproducible
perm_rng = np.random.default_rng(PERM_SEED)

# Must exist before the observed statistic is computed (it was previously
# defined only after first use, which fails on a fresh kernel).
isIRG = full_table.loc[:,('','IsIrrev')]
irg_flags = isIRG.to_numpy(dtype=bool)


def _irg_match_diff(irg, expr_chng, data_match, cutoff):
    """Return matches among IRGs minus matches among non-IRGs, for genes with expr_chng > cutoff.

    All arguments are gene-aligned numpy arrays; NaN in data_match is skipped,
    as pandas' .sum() does.
    """
    strong = expr_chng > cutoff
    return np.nansum(data_match[irg & strong]) - np.nansum(data_match[~irg & strong])


cond_arrays = {}
act_cond_diff_d = {}
for cond in ['batch','chemo']:
    cond_table = full_table.xs(cond,level=0,axis=1)
    expr_arr = cond_table.expr_chng.to_numpy(dtype=float)
    match_arr = cond_table.data_match.to_numpy(dtype=float)
    cond_arrays[cond] = (expr_arr, match_arr)
    act_cond_diff_d[cond] = _irg_match_diff(irg_flags, expr_arr, match_arr, thr)

# Shuffling the IRG flags across genes is equivalent to relabelling the rows of
# the expression table with a permuted gene index, without rebuilding DataFrames.
rnd_cond_diff_d = {}
for ii in range(N_samp):
    for cond, (expr_arr, match_arr) in cond_arrays.items():
        rnd_cond_diff_d[(ii,cond)] = _irg_match_diff(perm_rng.permutation(irg_flags), expr_arr, match_arr, thr)

## permutation p-value: fraction of shuffles exceeding the observed statistic in BOTH conditions
rnd_cond_diff_df = pd.Series(rnd_cond_diff_d).unstack()
exceeds_both = ((rnd_cond_diff_df['batch'] > act_cond_diff_d['batch'])
                & (rnd_cond_diff_df['chemo'] > act_cond_diff_d['chemo']))
nn_l = rnd_cond_diff_df.index[exceeds_both].tolist()
print(len(nn_l)/N_samp)

#### Significance of the average precision score in each condition

In [None]:
# average_precision_score was never imported anywhere in the notebook.
from sklearn.metrics import average_precision_score

# Observed AP: genes with expr_chng > thr, scored by expr_chng against data_match.
# Null: the same expr_chng scores paired with data_match labels drawn at random
# (without replacement) from all genes.  Uses thr and N_samp from the previous cell.
ap_rng = np.random.default_rng(54321)   # fixed seed for a reproducible p-value
for COND in ['batch','chemo']:
    dat=ROC_data_d2[COND]
    sel_dat = dat[dat.expr_chng>thr]
    INDS = sel_dat.expr_chng.sort_values().index[::-1]
    act = average_precision_score(sel_dat.loc[INDS].data_match,sel_dat.loc[INDS].expr_chng)
    # Null scores are the expr_chng values of the selected genes (not the
    # data_match labels, and never NaN), so null and observed APs are comparable.
    vals2 = sel_dat.loc[INDS].expr_chng.to_numpy()
    vals1 = dat.data_match.to_numpy()
    ap_l = []
    for ii in range(N_samp):
        new_vals1 = ap_rng.permutation(vals1)
        ap = average_precision_score(new_vals1[:len(INDS)],vals2)
        ap_l.append(ap)
    # Fraction of null APs above the observed one, followed by the observed AP.
    print(COND,np.mean(np.asarray(ap_l)>act),act)
