In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import roman
from statsmodels.robust.scale import mad
import seaborn as sns

OLIGO_TABLE_PATH = '../data_tables/oligos_nonuniq_crispey3_GG_9bp_OLIGO_with_seq_primers.txt'
#gxe_df =pd.read_csv('gxe_interaction_5cond_bc_filter_added_6repscm.tsv',sep = '\t')
#var_fitness_df =pd.read_csv('gxe_fitness_5cond_bcfilter_added_6repscm.tsv',sep = '\t')

# `sep` has to be given by keyword: every read_csv argument after the path is
# keyword-only in pandas >= 2.0, so a positional '\t' raises TypeError there.
oli_info = pd.read_csv(OLIGO_TABLE_PATH, sep='\t')


annotations = pd.read_csv('../data_tables/ergosterol_annotations.tsv', sep='\t')

# Keep only the text between the first '[' and the following ']' of each ALT entry
# (entries without a '[' become NaN).
oli_info['ALT'] = oli_info['ALT'].str.split('[').str[1].str.split(']').str[0]

# Inner join of annotations and oligos on the variant identity columns,
# then reduce to the annotation columns of interest and drop repeated rows.
annos = annotations.merge(
    oli_info,
    how='inner',
    on=['var_id', 'chrom', 'SNP_chr_pos', 'ALT'],
)
annos = annos[['var_id', 'chrom', 'SNP_chr_pos', 'Gene', 'Consequence', 'CDS_position', 'ALT', 'REF', 'chromosome']].drop_duplicates()

In [None]:
def robust_outlier_removal(umi_df, fc_cutoff=3.5, basemean_cutoff=5):
    '''
    Remove rows of ``umi_df`` whose log2FoldChange is a robust outlier.

    Every row is scored as ``|log2FoldChange - median| / MAD``, where the median
    and the MAD (``statsmodels.robust.scale.mad``) are computed once over all
    rows of ``umi_df``. Rows scoring strictly above ``fc_cutoff`` are dropped.

    Parameters
    ----------
    umi_df : pandas.DataFrame
        Frame with a 'log2FoldChange' column.
    fc_cutoff : float, default 3.5
        Largest score a row may have and still be kept.
    basemean_cutoff : float, default 5
        Currently unused: no baseMean filter is applied by this function. The
        parameter is kept so existing callers keep working.

    Returns
    -------
    pandas.DataFrame
        ``umi_df`` without the outlier rows (rows are dropped by index label).
        A one-row input returns a frame that has the same index but no columns;
        an empty input is returned as an unchanged copy.

    Notes
    -----
    When the MAD is 0, rows equal to the median get a NaN score and are kept,
    while every other row gets an infinite score and is dropped.
    '''
    if len(umi_df)==1:
        return pd.DataFrame(index=umi_df.index)

    fcs = umi_df['log2FoldChange'].values
    if fcs.size == 0:
        # Nothing to score; avoid computing statistics of an empty array.
        return umi_df.copy()

    # Both statistics are loop-invariant, so compute them once.
    center = np.median(fcs)
    spread = mad(fcs)

    # A zero MAD yields inf/NaN scores on purpose (see Notes); silence the
    # floating-point warnings that the division would otherwise emit.
    with np.errstate(divide='ignore', invalid='ignore'):
        scores = np.abs(fcs - center) / spread

    # NaN > cutoff is False, so NaN-scored rows are kept.
    is_outlier = scores > fc_cutoff
    return umi_df.drop(umi_df.index[is_outlier])

In [None]:
# Path of the UMI-level DESeq2 result table of one condition.
DESEQ_UMI_PATH_TEMPLATE = '../data_tables/Deseq_outputs/deseq2_res_bar_P1_{condition}_competitiontime_SCM_umi_level.tsv'


def load_umi_results(condition):
    '''
    Read the UMI-level DESeq2 table of one condition and add a 'barcode_id' column.

    Parameters
    ----------
    condition : str
        Condition tag used in the file name (e.g. 'LOV', 'SCM').

    Returns
    -------
    pandas.DataFrame
        The table as read by ``pd.read_csv(..., sep='\\t')`` plus 'barcode_id',
        built by joining the 2nd and 3rd '_'-separated fields of each index label.
        NOTE(review): this relies on the index holding string labels, i.e. on
        pandas taking the first file column as the index - confirm against the files.
    '''
    umi_table = pd.read_csv(DESEQ_UMI_PATH_TEMPLATE.format(condition=condition), sep='\t')
    umi_table['barcode_id'] = umi_table.index.str.split('_').str[1:3].str.join('_')
    return umi_table


lova = load_umi_results('LOV')
nacl = load_umi_results('NACL')
caff = load_umi_results('CAFF')
sc = load_umi_results('SCM')
cocl = load_umi_results('COCL')
tbf = load_umi_results('TBF')

def remove_outliers(df,title=None):
    '''
    Apply ``robust_outlier_removal`` within each 'barcode_id' group and plot the effect.

    A new figure is created showing baseMean (log-scaled x axis) against
    log2FoldChange: all rows of ``df`` in red, overdrawn in black by the rows
    that remain after filtering.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with 'barcode_id', 'baseMean' and 'log2FoldChange' columns.
    title : str, optional
        Title of the figure's axes.

    Returns
    -------
    pandas.DataFrame
        Concatenated per-barcode results with the 'barcode_id' index level removed.
    '''
    fig, ax = plt.subplots()

    # Red layer: every row before filtering.
    df.plot.scatter(x='baseMean', y='log2FoldChange', c='r', ax=ax)

    per_barcode = df.groupby('barcode_id').apply(robust_outlier_removal)
    kept = per_barcode.droplevel('barcode_id')

    # Black layer on top: rows that survived.
    kept.plot.scatter(x='baseMean', y='log2FoldChange', c='black', ax=ax)

    ax.set_title(title)
    ax.set_xscale('log')
    return kept

# Outlier-filter every condition; one diagnostic figure is drawn per call, in this order.
lova_filt, cocl_filt, caff_filt, nacl_filt, sc_filt, tbf_filt = [
    remove_outliers(umi_table, label)
    for umi_table, label in (
        (lova, 'LOV'),
        (cocl, 'COCL'),
        (caff, 'CAFF'),
        (nacl, 'NACL'),
        (sc, 'SC'),
        (tbf, 'TBF'),
    )
]

# Attach the oligo table to each filtered table (inner join on 'barcode_id').
lova_inf, cocl_inf, nacl_inf, caff_inf, sc_inf, tbf_inf = [
    filtered.merge(oli_info, how='inner', on='barcode_id')
    for filtered in (lova_filt, cocl_filt, nacl_filt, caff_filt, sc_filt, tbf_filt)
]

In [None]:

def bc_filter(df, cutoff):
    '''
    Keep the rows of ``df`` that belong to barcodes with a small log2FoldChange spread.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with 'barcode_id' and 'log2FoldChange' columns.
    cutoff : float
        A barcode is kept only if the standard deviation of its log2FoldChange
        values is strictly below this number.

    Returns
    -------
    pandas.DataFrame
        The matching rows of ``df``, original index and columns preserved.
        Barcodes whose standard deviation is NaN (e.g. a single row) are removed,
        because NaN is not below any cutoff.
    '''
    spread_by_barcode = df.groupby('barcode_id')['log2FoldChange'].std()
    tight_barcodes = spread_by_barcode.loc[spread_by_barcode.lt(cutoff)].index
    keep_rows = df['barcode_id'].isin(tight_barcodes)
    return df.loc[keep_rows]


# Keep only barcodes whose log2FoldChange standard deviation is below 0.05,
# using the same cutoff for every condition.
lova_inf, cocl_inf, nacl_inf, caff_inf, sc_inf, tbf_inf = [
    bc_filter(merged, 0.05)
    for merged in (lova_inf, cocl_inf, nacl_inf, caff_inf, sc_inf, tbf_inf)
]


In [None]:
# Figure-wide typography: Arial, 7 pt for text, axes titles and axes labels.
plt.rcParams.update({
    'font.family': 'Arial',
    'font.size': 7,
    'axes.titlesize': 7,
    'axes.labelsize': 7,
})

def var_plotter(var, ax=None):
    '''
    Swarm-plot the log2FoldChange values of one variant across the six conditions.

    Reads the notebook-level tables ``sc_inf``, ``caff_inf``, ``lova_inf``,
    ``tbf_inf``, ``nacl_inf``, ``cocl_inf`` and ``oli_info``. Each condition
    table is joined to the variant's ``oli_info`` rows on all shared columns, and
    the resulting rows are plotted in the order SCM, Caffeine, Lovastatin,
    Terbinafine, NaCl, CoCl2. The y axis is labelled 'Relative Fitness', a dashed
    line marks 0, and the title is 'chr<number>:<position><REF>><ALT>' taken
    from the variant's first ``oli_info`` row.

    Parameters
    ----------
    var : str
        Value of 'var_id' to plot.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).

    Raises
    ------
    IndexError
        If ``var`` does not occur in ``oli_info['var_id']``.
    '''
    # Look the variant up once instead of re-filtering oli_info per field.
    variant_rows = oli_info[oli_info['var_id'] == var]
    if variant_rows.empty:
        raise IndexError(f'var_id {var!r} not found in oli_info')
    ref = variant_rows['REF'].values[0]
    alt = variant_rows['ALT'].values[0]
    chrom = variant_rows['chrom'].values[0]
    pos = str(variant_rows['SNP_chr_pos'].values[0])

    condition_tables = [
        ('SCM', sc_inf),
        ('Caffeine', caff_inf),
        ('Lovastatin', lova_inf),
        ('Terbinafine', tbf_inf),
        ('NaCl', nacl_inf),
        ('CoCl2', cocl_inf),
    ]
    # Joining against this variant's oligo rows only selects the same rows as
    # joining against all of oli_info and filtering on var_id afterwards, without
    # building six full-size merged tables per call. `.assign` returns a new
    # frame, so no column is ever set on a slice of another frame.
    per_condition = [
        table.merge(variant_rows)[['log2FoldChange', 'barcode_id']].assign(Condition=label)
        for label, table in condition_tables
    ]
    # Fresh 0..n-1 index so the plotting frame never carries duplicate labels.
    points = pd.concat(per_condition, ignore_index=True)

    if ax is None:
        ax = plt.gca()
    sns.swarmplot(x='Condition', y='log2FoldChange', data=points, color='black', s=2, ax=ax)
    ax.set_ylabel('Relative Fitness')
    ax.axhline(0, linestyle='dashed', color='black', linewidth=.5)
    ax.set_title('chr' + str(roman.fromRoman(chrom)) + ':' + pos + ref + '>' + alt)

In [None]:
# Figure 5a: variant EGD_00098 on a 2 x 1.5 inch canvas.
fig, ax = plt.subplots(figsize=(2, 1.5))

var_plotter('EGD_00098')

# Tilt the condition names, drop the x-axis label, and fix the y ticks.
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_xlabel(None)
ax.set_yticks([.025, 0, -.025, -.05, -.075, -0.10])

fig.savefig('../GxE_Figures/Figure_5/fig5a_onecondition.svg')

In [None]:
# Figure 5b: variant ERG_00397 on a 2 x 1.5 inch canvas.
fig, ax = plt.subplots(figsize=(2, 1.5))

var_plotter('ERG_00397')

# Tilt the condition names, drop the x-axis label, and fix the y ticks.
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_xlabel(None)
ax.set_yticks([-.025, 0, .025, .05])

fig.savefig('../GxE_Figures/Figure_5/fig5b_onedirection.svg')

In [None]:
# Figure 5c: variant ERG_00086 on a 2 x 1.5 inch canvas.
fig, ax = plt.subplots(figsize=(2, 1.5))

var_plotter('ERG_00086')

# Tilt the condition names, drop the x-axis label, and fix the y ticks.
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_xlabel(None)
ax.set_yticks([.025, 0, -.025, -.05, -.075, -.10])

fig.savefig('../GxE_Figures/Figure_5/fig5c_signGxE.svg')