In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
from matplotlib import pyplot as pl
import numpy as np
import seaborn as sns
from fitness_assay_returns_more import inferFitness, inverseVarAve
import pandas as pd
from collections import defaultdict
from milo_tools import reverse_transcribe

# Constant sequence that sits between the two barcodes. sliding_window_min() scans
# Diverse.BC + middle_seq + reverse_transcribe(Environment.BC), so including it lets the
# longest AT-rich stretch of the whole region be measured (per the original note, that
# region is flanked by CG sequences).
# NOTE(review): presumably this is the lox site between the barcodes (the derived column
# is named 'min.lox.GC.w26') -- confirm against the construct map.
middle_seq = 'ATAACTTCGTATAATGTATGCTATACGAAGTTAT'

def gc(s):
    """Return the number of elements of s equal to 'G' or 'C' (upper case only)."""
    n_gc = 0
    for base in s:
        if base == 'G' or base == 'C':
            n_gc += 1
    return n_gc
    
def sliding_window_min(row, win_size):
    """
    Lowest G/C count found in any window of win_size consecutive bases.
    The scanned region is row['Diverse.BC'] + middle_seq + the reverse_transcribe()
    of row['Environment.BC'].
    :param row: mapping (e.g. a DataFrame row) with 'Diverse.BC' and 'Environment.BC' strings
    :param win_size: window length in bases
    :return: minimum over all full-length windows of gc(window)
    """
    region = row['Diverse.BC'] + middle_seq + reverse_transcribe(row['Environment.BC'])
    n_windows = len(region) - win_size + 1
    window_counts = [gc(region[start:start+win_size]) for start in range(n_windows)]
    return min(window_counts)
            
def inverseVarAve_w_nan(meanVals,standardDevs):
    """
    inverseVarAve_w_nan - take weighted average with inverse variances, skipping missing entries.
    An entry contributes only when BOTH its value and its standard error are non-NaN, so the
    same set of entries feeds the numerator, the normalisation and the returned error.
    Rows with no usable entry give a NaN mean and an infinite standard error.
    :param meanVals: Values to be averaged, N x q. Averaged across second dimension
    :param standardDevs: Standard errors of each value, N x q
    :return weightedMeans: length-N vector of weighted averages
    :return weightedStandardDevs: length-N vector of final standard errors
    """
    means = np.asarray(meanVals, dtype=float)
    errs = np.asarray(standardDevs, dtype=float)

    # 0/0 (row with no data) and 1/0 (zero error) are expected here; keep them quiet.
    with np.errstate(divide='ignore', invalid='ignore'):
        weights = np.power(errs, -2)
        # Drop the weight of any entry whose value is missing. Otherwise it would still
        # inflate the denominator, pulling the mean toward 0 and shrinking the error.
        weights = np.where(np.isnan(means) | np.isnan(errs), np.nan, weights)
        weightSums = np.nansum(weights, axis=1)
        weightedMeans = np.nansum(means*weights, axis=1)/weightSums
        weightedStandardDevs = np.power(weightSums, -0.5)

    return weightedMeans,weightedStandardDevs

def measure_fitness(td, putative_neutral_bcs, tp_ex, bfa_name):
    """
    Infer per-barcode fitness for every environment in a count table and store it in td.
    td is modified in place: it is sorted by 'Full.BC', the count columns of excluded
    timepoints are overwritten with zeros, and for each environment with usable data these
    columns are added: '<env>-iva_s', '<env>-iva_s_err', '<env>-ave_s', '<env>-ave_err',
    '<env>-<rep>-aveFitness', '<env>-<rep>-aveError' and 'used_as_neutral_in_<env>'.
    :param td: DataFrame with a 'Full.BC' column and count columns named
               '<bfa_name>-<env>-<rep>-Time<t>' for t in 8, 16, 24, 32, 40
    :param putative_neutral_bcs: barcodes passed to inferFitness as neutralBarcodes
    :param tp_ex: nested mapping, tp_ex[env][rep] = list of excluded timepoint labels
                  (e.g. '16'), or a list containing 'EXCLUDE ALL' to drop the replicate
    :param bfa_name: assay name used as the prefix of the count columns
    :return: None
    """
    c_times = [1, 2, 3, 4, 5]
    low_coverage_thresh = 1e5
    td.sort_values(by='Full.BC', inplace=True)
    bcs = list(td['Full.BC'])
    time_cols = [c for c in td.columns if 'Time' in c]
    envs = set(c.split('-')[1] for c in time_cols) - {'Pre', 'T0_Pool'}
    for env in envs:
        read_dat = dict()
        # Match the environment field exactly; a substring test would also pick up the
        # replicates of any environment whose name merely contains this one.
        reps = sorted(set(c.split('-')[2] for c in time_cols if c.split('-')[1] == env))
        for rep in reps:
            excluded_tps = tp_ex[env][rep]
            if 'EXCLUDE ALL' in excluded_tps:
                continue
            tps = [bfa_name + '-' + env + '-' + rep + '-Time' + str(i*8) for i in c_times]
            # to exclude timepoints just zero out the counts so they are caught by the low coverage thresh
            for i, tp in zip(c_times, tps):
                if str(i*8) in excluded_tps:
                    td[tp] = np.zeros(len(td))
            # .values replaces DataFrame.as_matrix, which was removed in pandas 1.0
            tmp_read_dat = np.nan_to_num(td[tps].values)
            coverage = np.sum(tmp_read_dat, axis=0)
            included_tps = [i for i in c_times if coverage[i-1] > low_coverage_thresh]
            # a replicate needs at least two well-covered timepoints to be used
            if len(included_tps) >= 2:
                read_dat[rep] = tmp_read_dat
        if len(read_dat) > 0:
            env_fit, used_neuts = inferFitness(bcs, c_times, read_dat, outputFolder = 'test_Atish_out/',
                                               experimentName = bfa_name+'-'+env+'-', lowCoverageThresh=low_coverage_thresh,
                                               neutralBarcodes=putative_neutral_bcs)
            # averaging across replicates and adding the results to td
            fit_reps = [k for k in env_fit.keys() if 'R' in k]
            s_aves = np.array([env_fit[r]['aveFitness'] for r in fit_reps]).T
            s_errs = np.array([env_fit[r]['aveError'] for r in fit_reps]).T
            td[env + '-iva_s'], td[env + '-iva_s_err'] = inverseVarAve_w_nan(s_aves, s_errs)
            td[env + '-ave_s'] = np.nanmean(s_aves, axis=1)
            # NOTE(review): np.mean (not nanmean) makes this NaN whenever any replicate's
            # error is NaN, even though '-ave_s' above ignores NaNs -- confirm intended.
            td[env + '-ave_err'] = np.power(np.mean(np.power(s_errs, 2),axis=1), 0.5)
            for r in fit_reps:
                td[env + '-' + r + '-aveFitness'] = env_fit[r]['aveFitness']
                td[env + '-' + r + '-aveError'] = env_fit[r]['aveError']
            td['used_as_neutral_in_' + env] = used_neuts
            
def plot_s_corr(td, bfa_name, output_base, file_end):
    """
    Save one replicate-vs-replicate scatter grid per environment.
    Environments are taken from the '<env>-iva_s*' columns of td and replicates from the
    '<env>-R*-aveFitness' columns; environments with fewer than two replicates are skipped.
    Off-diagonal panels compare two replicates, diagonal panels compare a replicate with the
    '<env>-iva_s' column. All plotted values are divided by 8.
    :param td: DataFrame holding the fitness columns described above
    :param bfa_name: assay name, used in the output file name
    :param output_base: path prefix of the output files
    :param file_end: suffix (including extension) of the output files
    """
    envs = set(col.split('-')[0] for col in td.columns if '-iva_s' in col)
    for env in envs:
        reps = [col.split('-')[1] for col in td.columns if env+'-R' in col and '-aveFitness' in col]
        n_reps = len(reps)
        if n_reps <= 1:
            continue
        fig, subps = pl.subplots(n_reps, n_reps, figsize=(5*n_reps, 5*n_reps), sharex=True, sharey=True)
        combined = td[env + '-iva_s']/8
        # end points of the dashed y = x guide; the +/-10 bounds also drop NaN and inf values
        upper = max([v for v in combined if v < 10])
        lower = min([v for v in combined if v > -10])
        for i, rep_x in enumerate(reps):
            x_vals = td[env + '-' + rep_x + '-aveFitness']/8
            for j, rep_y in enumerate(reps):
                ax = subps[i][j]
                ax.plot([upper, lower], [upper, lower], linestyle='dashed', color='k', alpha=0.5)
                if i == j:
                    ax.scatter(x_vals, combined)
                    y_label = 'replicate average'
                else:
                    ax.scatter(x_vals, td[env + '-' + rep_y + '-aveFitness']/8)
                    y_label = rep_y
                ax.set_xlabel(rep_x)
                ax.set_ylabel(y_label)

        fig.savefig(output_base + bfa_name + '-' + env + '_iva' + file_end)
        pl.close("all")

            
def column_is_s_related(c):
    """Return True when column name c contains any of the fitness-output suffixes."""
    fitness_tags = ('-iva_s', 'iva_s_err', '-ave_s', '-ave_err', '-aveFitness', '-aveError')
    return any(tag in c for tag in fitness_tags)

# Identifier columns written next to the fitness columns in every output CSV
info_cols = ['Diverse.BC', 'Environment.BC', 'Full.BC', 'Subpool.Environment', 'Which.Subpools']

dbfa2 = pd.read_csv('../Final_Count_Pipeline/BFA_data/Combined_Counts/dBFA2_counts_with_env_info.csv')
hbfa1 = pd.read_csv('../Final_Count_Pipeline/BFA_data/Combined_Counts/hBFA1_counts_with_env_info.csv')
hbfa2 = pd.read_csv('../Final_Count_Pipeline/BFA_data/Combined_Counts/hBFA2_counts_with_env_info.csv')
dats = {'dBFA2': dbfa2, 'hBFA1': hbfa1, 'hBFA2': hbfa2}

# 'Subpool.Environment' value that marks the putative neutral (control) barcodes of each assay
control_subpool_env = {'hBFA1': 'YPD_alpha', 'hBFA2': 'CLM_2N', 'dBFA2': 'Ancestor_YPD_2N'}


def _control_bcs(bfa_name, frame):
    """Return the 'Full.BC' values of the control barcodes of assay bfa_name found in frame."""
    is_control = frame['Subpool.Environment'] == control_subpool_env[bfa_name]
    if bfa_name == 'dBFA2':
        # dBFA2 controls are further restricted to the '-R1-1' subpool. The mask is built
        # from the frame being filtered, not from the full dbfa2 table.
        is_control = is_control & (frame['Which.Subpools'] == '-R1-1')
    return list(frame.loc[is_control, 'Full.BC'])


putative_neuts = {b: _control_bcs(b, dats[b]) for b in control_subpool_env}

for b in putative_neuts:
    print(b, 'has', len(putative_neuts[b]), 'control bcs')

tp_exclusion = pd.read_csv('bfa_timepoint_exclusion_list.csv')
tp_ex_by_bfa = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # dict like tp_ex[bfa_name][env_name][replicate] = list of excluded tps
# .values replaces DataFrame.as_matrix, which was removed in pandas 1.0
for assay, env, rep, times in tp_exclusion[['ASSAY', 'ENV', 'REP', 'TIME']].values:
    tp_ex_by_bfa[assay][env][rep] = times.split(';')
print('Example', tp_ex_by_bfa['hBFA1']['FLC4'])

# Fitness estimates using every barcode
for b in dats:
    measure_fitness(dats[b], putative_neuts[b], tp_ex_by_bfa[b], b)
    cols = [i for i in dats[b].columns if column_is_s_related(i)]
    dats[b][info_cols + cols].to_csv('03_23_18_fitness_estimates/' + b + '_s_03_23_18_all.csv', index=False)
    plot_s_corr(dats[b], b, 's_graphs/s_correlations/', '_all.png')

# The GC-window statistic does not depend on the cutoff, so compute it once per assay
for b in dats:
    dats[b]['min.lox.GC.w26'] = dats[b].apply(lambda row: sliding_window_min(row, 26), axis=1)

# Repeat the inference keeping only barcodes whose poorest 26-base window has enough G/C
for gc_cutoff in [4, 5]:
    print('Min GC #/26 is', gc_cutoff)
    at_reduced = dict()
    for b in dats:
        td = dats[b]
        # Work on an independent copy: measure_fitness sorts and adds columns in place.
        reduced = td.loc[td['min.lox.GC.w26'] >= gc_cutoff].copy()
        # Drop the fitness columns inherited from the all-barcode run so that none of them
        # can reach this cutoff's CSV/plots if a replicate or environment is not re-estimated.
        stale_cols = [c for c in reduced.columns if column_is_s_related(c)]
        reduced.drop(stale_cols, axis=1, inplace=True)
        at_reduced[b] = reduced
        print(b, len(td), 'bcs,', len(reduced), 'meet this cutoff, or', int(100*len(reduced)/len(td)), '%')

    at_reduced_neuts = {b: _control_bcs(b, at_reduced[b]) for b in control_subpool_env}

    for b in at_reduced_neuts:
        print(b, 'has', len(at_reduced_neuts[b]), 'control bcs')

    for b in at_reduced:
        measure_fitness(at_reduced[b], at_reduced_neuts[b], tp_ex_by_bfa[b], b)
        cols = [i for i in at_reduced[b].columns if column_is_s_related(i)]
        at_reduced[b][info_cols + cols].to_csv('03_23_18_fitness_estimates/' + b + '_s_03_23_18_GC_cutoff_' + str(gc_cutoff) + '.csv', index=False)
        plot_s_corr(at_reduced[b], b, 's_graphs/s_correlations/', '_GC_cutoff_' + str(gc_cutoff) + '.png')

hBFA1 has 304 control bcs
hBFA2 has 367 control bcs
dBFA2 has 177 control bcs
Example defaultdict(<class 'list'>, {'R2': ['16']})
Min GC #/26 is 4
dBFA2 5866 bcs, 5530 meet this cutoff, or 94 %
hBFA1 2586 bcs, 2316 meet this cutoff, or 89 %
hBFA2 3802 bcs, 3106 meet this cutoff, or 81 %
hBFA1 has 288 control bcs
hBFA2 has 345 control bcs
dBFA2 has 165 control bcs
Min GC #/26 is 5
dBFA2 5866 bcs, 4314 meet this cutoff, or 73 %
hBFA1 2586 bcs, 1847 meet this cutoff, or 71 %
hBFA2 3802 bcs, 2534 meet this cutoff, or 66 %
hBFA1 has 242 control bcs
hBFA2 has 298 control bcs
dBFA2 has 91 control bcs
