In [1]:
import numpy as np
import pandas as pd
import scipy

In [2]:
pd.options.display.max_rows = 20
pd.set_option('display.max_columns', None)

In [3]:
import os
import sys
module_path = os.path.abspath(os.path.join('../'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [4]:
import importlib; 
import bias_correction; 
importlib.reload( bias_correction ); 
from bias_correction import BiasCorrection

-------

In [5]:
np.random.seed(3469117)

-------

# Case Study Bias Correction

#### Set up bias correction object

In [6]:
debiaser = BiasCorrection(counts='./data/kinsler-2020-preprocessed-data/counts.csv', 
                          samples='./data/kinsler-2020-preprocessed-data/samples.csv', 
                          variants='./data/kinsler-2020-preprocessed-data/variants_withNeutralGroups.csv',
                          config='./analysis-config.json',
                          outdir='./results/')

#### Perform bias correction

In [7]:
debiaser.run()

[ Stage 1: Inferring bias susceptibilities and bias prevalence deviations. ]
    Iteration #1                                                  
    Iteration #2                                                                                                           
    Iteration #3                                                                                                           
    Iteration #4                                                                                                           
    Iteration #5                                                                                                           
[ Stage 2: Inferring bias prevalence trends. ]                                                                             
[ Computing bias-corrected fitness estimates. ]
[ Done. ]


#### Save results of bias correction

In [8]:
# debiaser.save(outdir='./results/')

In [9]:
stop

NameError: name 'stop' is not defined

# Results Figures

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.transforms as mtransforms
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
from matplotlib import cm
import seaborn as sns

In [None]:
sns.set_style('ticks')

In [None]:
plt.rcParams['font.size'] = '7'

In [None]:
plt.rcParams['axes.linewidth'] = 0.66

In [None]:
num_steps = 7
colors = []
center_color = (0.92, 0.92, 0.92)
base_cmap = matplotlib.cm.get_cmap('BrBG')
for step in range(num_steps):
    if(step == int(num_steps/2)):
        colors.append(center_color)
    else:
        colors.append(base_cmap(step/(num_steps-1)))
biasCmapG = matplotlib.colors.LinearSegmentedColormap.from_list('bias_g', colors, N=256)
# display(base_cmap)
display(biasCmapG)

biasCmapG_vmin, biasCmapG_vmax = -1, 1
biasCmapG_norm = matplotlib.colors.Normalize(vmin=biasCmapG_vmin, vmax=biasCmapG_vmax)

In [None]:
num_steps = 7
colors = []
center_color = (0.995, 0.995, 0.995)
base_cmap = matplotlib.cm.get_cmap('BrBG_r')
for step in range(num_steps):
    if(step == int(num_steps/2)):
        colors.append(center_color)
    else:
        colors.append(base_cmap(step/(num_steps-1)))
biasCmapW = matplotlib.colors.LinearSegmentedColormap.from_list('bias_w', colors, N=256)
# display(base_cmap)
display(biasCmapW)

biasCmapW_vmin, biasCmapW_vmax = -1, 1
biasCmapW_norm = matplotlib.colors.Normalize(vmin=biasCmapW_vmin, vmax=biasCmapW_vmax)

In [None]:
num_steps = 7
colors = []
center_color = (0.92, 0.92, 0.92)
base_cmap = matplotlib.cm.get_cmap('RdBu')
for step in range(num_steps):
    if(step == int(num_steps/2)):
        colors.append(center_color)
    else:
        colors.append(base_cmap(step/(num_steps-1)))
residCmapG = matplotlib.colors.LinearSegmentedColormap.from_list('residual_g', colors, N=256)
display(residCmapG)

residCmapG_vmin, residCmapG_vmax = -1, 1
residCmapG_norm = matplotlib.colors.Normalize(vmin=residCmapG_vmin, vmax=residCmapG_vmax)

In [None]:
num_steps = 7
colors = []
center_color = (0.995, 0.995, 0.995)
base_cmap = matplotlib.cm.get_cmap('RdBu_r')
for step in range(num_steps):
    if(step == int(num_steps/2)):
        colors.append(center_color)
    else:
        colors.append(base_cmap(step/(num_steps-1)))
residCmapW = matplotlib.colors.LinearSegmentedColormap.from_list('residual_w', colors, N=256)
# display(base_cmap)
display(residCmapW)

residCmapW_vmin, residCmapW_vmax = -1, 1
residCmapW_norm = matplotlib.colors.Normalize(vmin=residCmapW_vmin, vmax=residCmapW_vmax)

In [None]:
num_steps = 7
colors = []
center_color = (0.92, 0.92, 0.92)
base_cmap = matplotlib.cm.get_cmap('PuOr')
for step in range(num_steps):
    if(step == int(num_steps/2)):
        colors.append(center_color)
    else:
        colors.append(base_cmap(step/(num_steps-1)))
fitnessCmapG = matplotlib.colors.LinearSegmentedColormap.from_list('fitnessual_w', colors, N=256)
# display(base_cmap)
display(fitnessCmapG)

fitnessCmapG_vmin, fitnessCmapG_vmax = -1, 1
fitnessCmapG_norm = matplotlib.colors.Normalize(vmin=fitnessCmapG_vmin, vmax=fitnessCmapG_vmax)

In [None]:
num_steps = 7
colors = []
center_color = (0.995, 0.995, 0.995)
base_cmap = matplotlib.cm.get_cmap('PuOr_r')
for step in range(num_steps):
    if(step == int(num_steps/2)):
        colors.append(center_color)
    else:
        colors.append(base_cmap(step/(num_steps-1)))
fitnessCmapW = matplotlib.colors.LinearSegmentedColormap.from_list('fitnessual_w', colors, N=256)
# display(base_cmap)
display(fitnessCmapW)

fitnessCmapW_vmin, fitnessCmapW_vmax = -1, 1
fitnessCmapW_norm = matplotlib.colors.Normalize(vmin=fitnessCmapW_vmin, vmax=fitnessCmapW_vmax)

In [None]:
palette = sns.color_palette('Set2', 8)
groups_colors = {'Diploid': palette[2], 'GPB2': palette[4], 'PDE2': palette[3], 'Other': palette[-2],  'IRA1': palette[-1]}
palette

------

In [None]:
# VIZ_RESULTS = 'matlab'
VIZ_RESULTS = 'python'

In [None]:
if VIZ_RESULTS == 'matlab':
    viz_variants = matlab_resids_final.loc[:, [s for s in viz_samples if 'TX' not in s and 'T0' not in s]].dropna().index.values
else:
    goodForVizDf = pd.DataFrame(index=debiaser.trustworthy.index, columns=debiaser.trustworthy.columns)
    for assay in debiaser.samplesInfo['assay'].unique():
        assayInfo = debiaser.samplesInfo.loc[(debiaser.samplesInfo['assay'] == assay) & (~debiaser.samplesInfo['timept'].isin(debiaser.cfg['exclude_timepts']))].sort_values(by='timept')
        samples_a = assayInfo['sample'].unique()
        mean_counts_a = debiaser.counts[samples_a].mean(axis=1).values
        goodForVizDf.loc[:, assay] = (mean_counts_a > 25)
    viz_variants = np.where(goodForVizDf.all(axis=1))[0]
    
print(viz_variants, len(viz_variants))

In [None]:
viz_biassusc_final = matlab_biassusc_final if VIZ_RESULTS == 'matlab' else debiaser.variantsInfo['bias_susceptibility'].values

with np.printoptions(threshold=100):
    print(viz_biassusc_final)

In [None]:
viz_biasprev_final = matlab_biasprev_final if VIZ_RESULTS == 'matlab' else pd.DataFrame(dict(zip(debiaser.samplesInfo['sample'], debiaser.samplesInfo['bias_prevalence'])), index=[0])

viz_biasprev_final

In [None]:
viz_resids_orig = matlab_resids_orig if VIZ_RESULTS == 'matlab' else debiaser.log['observed']['residuals']

viz_resids_orig

In [None]:
viz_resids_final = matlab_resids_final if VIZ_RESULTS == 'matlab' else debiaser.residuals

viz_resids_final

In [None]:
if VIZ_RESULTS == 'matlab':
    viz_fitnesses_orig = {group: matlab_fitnesses_orig.loc[debiaser.variantsInfo['neutral_group'] == group] for group in debiaser.variantsInfo['neutral_group'].dropna().unique() if group is not None}
    viz_fitnesses_orig['all'] = matlab_fitnesses_orig.copy()
else:
    viz_fitnesses_orig = {group: debiaser.log['observed']['fitnesses'].loc[[v for v in viz_variants if v in np.where(debiaser.variantsInfo['neutral_group'] == group)[0]]] for group in debiaser.variantsInfo['neutral_group'].dropna().unique() if group is not None}
    viz_fitnesses_orig['all'] = debiaser.log['observed']['fitnesses'].copy()
    
viz_fitnesses_orig['all']

In [None]:
if VIZ_RESULTS == 'matlab':
    viz_fitnesses_final = {group: matlab_fitnesses_final.loc[debiaser.variantsInfo['neutral_group'] == group] for group in debiaser.variantsInfo['neutral_group'].dropna().unique() if group is not None}
    viz_fitnesses_final['all'] = matlab_fitnesses_final.copy()
else:
    viz_fitnesses_final = {group: debiaser.fitnesses.loc[debiaser.variantsInfo['neutral_group'] == group] for group in debiaser.variantsInfo['neutral_group'].dropna().unique() if group is not None}
    viz_fitnesses_final['all'] = debiaser.fitnesses.copy()

viz_fitnesses_final['Diploid']

In [None]:
ECs_fig4 = [23, 21, 20, 18, 13, 3]

In [None]:
viz_assays = np.concatenate([[col for col in viz_fitnesses_orig['Diploid'].columns if f"EC{ECnum}" in col] for ECnum in ECs_fig4])
viz_assays

In [None]:
viz_samples = np.concatenate([[col for col in viz_resids_orig.columns if f"EC{ECnum}" in col] for ECnum in [23, 21, 20, 18, 13, 3]])
viz_samples

In [None]:
variants_GCsorted = viz_resids_final.loc[debiaser.variantsInfo.sort_values(by='barcode_GCratio').index.values, [a for a in viz_samples if 'TX' not in a and 'T0' not in a]].dropna().index.values
print(variants_GCsorted, len(variants_GCsorted))

In [None]:
viz_variants_GCsorted = [v for v in variants_GCsorted if v in viz_variants]
print(viz_variants_GCsorted, len(viz_variants_GCsorted))

In [None]:
viz_fEpsilons = {group: viz_fitnesses_orig[group] - viz_fitnesses_orig[group].mean(axis=0) for group in viz_fitnesses_orig.keys()}
viz_fEpsilons.keys()

In [None]:
viz_fDeltas = {group: viz_fitnesses_final[group] - viz_fitnesses_orig[group] for group in viz_fitnesses_orig.keys()}
viz_fDeltas.keys()

## Figure 4

##### Adapted from the previous 2023-08-21 version (see notebook 2023-05-16_figures.ipynb)

In [None]:
# variants_fig2C = viz_resids_orig.loc[debiaser.variantsInfo.sort_values(by='barcode_GCratio').index.values, [a for a in viz_samples if 'TX' not in a]].dropna().index.values
# variants_fig2C

In [None]:
# viz_resids_final.loc[viz_variants_GCsorted, samples_fig4A_EC[1:]]

In [None]:
# groups_colors

In [None]:
assay_relabels = {
'EC23': 'EC1',
'EC21': 'EC2',
'EC20': 'EC3',
'EC18': 'EC4',
'EC13': 'EC5',
'EC3': 'EC6',
'Baffled': 'Baffled',
'1.4%Gluc': '1.4% Gluc',
'1.6%Gluc': '1.6% Gluc',
'1.8%Gluc': '1.8% Gluc',
'0.5%Raf': '0.5% Raf',
'1.5%Suc1%Raf': '1.5% Suc, 1% Raf',
'0.2MKCl': '0.2M KCl',
'0.5MKCl': '0.5M KCl',
'M3': 'ECBB',
}

In [None]:
##### This here is the new 2023-10-19+ version

figsize = (7, 4.0) 
fig4 = plt.figure(constrained_layout=False, figsize=figsize, dpi=250) # 250

gridsize = (int(figsize[0]*10*2), int(figsize[1]*10*2))
gs = fig4.add_gridspec(gridsize[1], gridsize[0])

gridHeight_v       = 2
gridHeight_A       = int(gridsize[1]*0.45)-1
gridHeight_B       = int(gridsize[1]*0.15)
gridHeight_C       = int(gridsize[1]*0.2)-1
gridHeight_D       = int(gridsize[1]*0.4)-1
gridHeight_E       = int(gridsize[1]*0.4)-1

gridHeight_vAspace = 1
gridHeight_ABspace = int(gridsize[1]*0.115)+1
gridHeight_BCspace = int(gridsize[1]*0.085)+1
gridHeight_DEspace = int(1)

gridWidth_uAspace  = 1
gridWidth_ECspace  = 1

gridWidth_u        = 1
gridWidth_ABC      = int(gridsize[0]*0.70)
gridWidth_DE       = int(gridsize[0]*0.22)
gridWidth_EC       = int(gridWidth_ABC/len(ECs_fig4)) - gridWidth_ECspace
gridWidth_cbar     = 1

ax4A_u    = fig4.add_subplot(gs[(gridHeight_v + gridHeight_vAspace):(gridHeight_v + gridHeight_vAspace + gridHeight_A), :gridWidth_u])
ax4Av_ECs = []
ax4A_ECs  = []
ax4B_ECs  = []
ax4C_ECs  = []
for i in range(len(ECs_fig4)):
    gridx_start = (gridWidth_u + gridWidth_uAspace + i*gridWidth_EC + i*gridWidth_ECspace)
    gridx_end   = gridx_start + gridWidth_EC
    ax4Av_ECs.append( fig4.add_subplot(gs[0:gridHeight_v, gridx_start:gridx_end]) )
    ax4A_ECs.append( fig4.add_subplot(gs[(gridHeight_v + gridHeight_vAspace):(gridHeight_v + gridHeight_vAspace + gridHeight_A), gridx_start:gridx_end]) )
    ax4B_ECs.append( fig4.add_subplot(gs[(-gridHeight_C - gridHeight_BCspace - gridHeight_B):(-gridHeight_C - gridHeight_BCspace), gridx_start:gridx_end], sharex=ax4B_ECs[0] if i!=0 else None) )
    ax4C_ECs.append( fig4.add_subplot(gs[(-gridHeight_C):, gridx_start:gridx_end], sharex=ax4C_ECs[0] if i!=0 else None) )
ax4A_cbarBias = fig4.add_subplot(gs[(gridHeight_v + gridHeight_vAspace):(gridHeight_v + gridHeight_vAspace + int(gridHeight_A/2)-1), (gridx_end + gridWidth_ECspace)+1:(gridx_end + gridWidth_ECspace + gridWidth_cbar)+1])
ax4A_cbarResids = fig4.add_subplot(gs[(gridHeight_v + gridHeight_vAspace + int(gridHeight_A/2)+1):(gridHeight_v + gridHeight_vAspace + gridHeight_A), (gridx_end + gridWidth_ECspace)+1:(gridx_end + gridWidth_ECspace + gridWidth_cbar)+1])

ax4D = fig4.add_subplot(gs[(1 + gridHeight_v + gridHeight_vAspace):(1 + gridHeight_v + gridHeight_vAspace + gridHeight_D), -gridWidth_DE:])
ax4E = fig4.add_subplot(gs[-(1 + gridHeight_E):-(1+1), -gridWidth_DE:])

for ax in ax4A_ECs + ax4Av_ECs + [ax4A_u, ax4D, ax4E]: 
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

sns.heatmap(ax=ax4A_u, data=np.atleast_2d(viz_biassusc_final[viz_variants_GCsorted]).T, cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, cbar=False)
ax4A_u.set_xticks([])
ax4A_u.set_yticks([])
# ax4A_u.set_ylabel("bias susceptibility", labelpad=0, fontsize=6)
ax4A_u.text(-2.1, len(viz_biassusc_final[viz_variants_GCsorted])/2, "inferred bias susceptibility", c='k', verticalalignment='center', horizontalalignment='center', rotation=90, fontsize=6, weight='bold', zorder=100)
for spine in ax4A_u.spines.values():
    spine.set(visible=True, lw=.4, edgecolor="black")

for c, ECnum in enumerate(ECs_fig4):
    
    samples_fig4A_EC = [col for col in viz_resids_final.columns if f"EC{ECnum}" in col]
    print(samples_fig4A_EC)
    # print("aaa")
    #----------
    hm_bias = sns.heatmap(ax=ax4Av_ECs[c], data=viz_biasprev_final[samples_fig4A_EC[1:]], cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, square=False, linewidths=0.0, linecolor='w',
                          cbar=(c==len(ECs_fig4)-1), cbar_ax=(ax4A_cbarBias if c==len(ECs_fig4)-1 else None), cbar_kws=({'label':"bias", 'drawedges':False, 'ticks':[-1, 1]} if c==len(ECs_fig4)-1 else None))
    ax4Av_ECs[c].set_xticks([])
    ax4Av_ECs[c].set_yticks([])
    ax4Av_ECs[c].add_patch(mpatches.Rectangle(xy=(0, 0), width=15, height=0.4, fc='w', ec='w', lw=0.0, clip_on=False, zorder=99))
    ax4Av_ECs[c].add_patch(mpatches.Rectangle(xy=(0*len(samples_fig4A_EC)/3+0.02, 0.405), width=(len(samples_fig4A_EC)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    ax4Av_ECs[c].add_patch(mpatches.Rectangle(xy=(1*len(samples_fig4A_EC)/3+0.02, 0.405), width=(len(samples_fig4A_EC)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    ax4Av_ECs[c].add_patch(mpatches.Rectangle(xy=(2*len(samples_fig4A_EC)/3+0.02, 0.405), width=(len(samples_fig4A_EC)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))    
    # for spine in ax4Av_ECs[c].spines.values():
    #     spine.set(visible=True, lw=.3, edgecolor="black")
    # print("bbb")
    #----------
    hm_resids = sns.heatmap(ax=ax4A_ECs[c], data=viz_resids_final.loc[viz_variants_GCsorted, samples_fig4A_EC[1:]], cmap=residCmapW, center=0, vmin=-1, vmax=1,  #.shift(-1, axis=1)
                            cbar=(c==len(ECs_fig4)-1), cbar_ax=(ax4A_cbarResids if c==len(ECs_fig4)-1 else None), cbar_kws=({'label':"residuals", 'drawedges':False, 'ticks':[-1, 1]} if c==len(ECs_fig4)-1 else None))
    ax4A_ECs[c].set_xticks([])
    ax4A_ECs[c].set_yticks([])
    # print("ccc")
    #----------
    ax4A_ECs[c].text(len(samples_fig4A_EC[:-1])/2, len(viz_variants_GCsorted)+45, assay_relabels[f"EC{ECnum}"], c='k', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax4A_ECs[c].text(2.25 if len(samples_fig4A_EC[:-1]) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax4A_ECs[c].text(len(samples_fig4A_EC[:-1])/2, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax4A_ECs[c].text(12.25 if len(samples_fig4A_EC[:-1]) == 14 else 9.75, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    # ax4A_ECs[c].plot([0, len(samples_fig4A_EC[:-1])], [len(viz_variants_GCsorted)+32, len(viz_variants_GCsorted)+32], c='k', lw=0.5, clip_on=False)
    ax4A_ECs[c].add_patch(mpatches.Rectangle(xy=(0, len(viz_variants_GCsorted)+38), width=len(samples_fig4A_EC[:-1]), height=32, fc='#eee', ec='k', lw=0.0, clip_on=False, zorder=1))
    # print("ddd")
    
    # break
    
ax4Av_ECs[3].text(0, -0.7, "inferred bias prevalence", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, weight='bold', zorder=100)

# ax4A_cbarBias.yaxis.label.set_size(4)
# ax4A_cbarBias.yaxis.get_label().set(fontsize=4, position=(0, 0))

ax4A_cbarBias.set_yticks([-1, 1])
ax4A_cbarBias.set_yticklabels([-1, 1], fontsize=4)
ax4A_cbarBias.set_ylabel("bias", fontsize=5, labelpad=-3)
ax4A_cbarBias.tick_params(width=0.5, length=1.33, pad=1)
ax4A_cbarResids.set_yticks([-1, 1])
ax4A_cbarResids.set_yticklabels([-1, 1], fontsize=4)
ax4A_cbarResids.set_ylabel("bias-adj. residuals", fontsize=5, labelpad=-3)
ax4A_cbarResids.tick_params(width=0.5, length=1.33, pad=1)
for spine in ax4A_cbarBias.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")
for spine in ax4A_cbarResids.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")
    
ax4A_u.text(-4, -55, "A", verticalalignment='top', horizontalalignment='left', fontsize=7, weight='bold', rotation=0)
ax4A_u.text(-4, 425, "B", verticalalignment='top', horizontalalignment='left', fontsize=7, weight='bold', rotation=0)
ax4A_u.text(-4, 610, "C", verticalalignment='top', horizontalalignment='left', fontsize=7, weight='bold', rotation=0)
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

vizGroups_fig4B = ['Diploid', 'GPB2', 'PDE2']

for c, ECnum in enumerate(ECs_fig4):
    viz_assays_EC = [col for col in viz_fitnesses_final['all'].columns if f"EC{ECnum}" in col]
    print(viz_assays_EC)
    
    boxplot_fitnesses = []
    
    box_positions = [0.8, 1.2, 1.8, 2.2, 2.8, 3.2]
    
    for g, group in enumerate(vizGroups_fig4B):
        viz_fitnesses_orig_g_a = np.array([])
        viz_fitnesses_final_g_a = np.array([])
        for a, assay in enumerate(viz_assays_EC):
            viz_fitnesses_orig_g_a  = np.hstack([viz_fitnesses_orig_g_a, viz_fitnesses_orig[group].loc[:, assay].values])
            viz_fitnesses_final_g_a = np.hstack([viz_fitnesses_final_g_a, viz_fitnesses_final[group].loc[:, assay].values])
            
        pvalue = scipy.stats.levene(viz_fitnesses_orig_g_a, viz_fitnesses_final_g_a).pvalue
        if(pvalue < 0.05):
            ax4B_ECs[c].text(np.percentile(viz_fitnesses_final_g_a, 99)+(0.025 if ECnum == 20 and group == 'PDE2' else 0.125), box_positions[::-1][2*g+1]-0.15, "*", verticalalignment='center', horizontalalignment='center', fontsize=6, color=groups_colors[group])
            
        boxplot_fitnesses.append(viz_fitnesses_orig_g_a)
        boxplot_fitnesses.append(viz_fitnesses_final_g_a)
    
    bplot = ax4B_ECs[c].boxplot(boxplot_fitnesses[::-1], vert=False, positions=box_positions, widths=0.25, showcaps=False, showfliers=False, patch_artist=True,
                                boxprops=dict(linewidth=0.66, facecolor='r'), whiskerprops=dict(linewidth=0.66), medianprops=dict(linewidth=0.66))
    
    boxshadealpha = 0.5
    for i, patch in enumerate(bplot['boxes'][::-1]):
        patch.set_facecolor(list(groups_colors[vizGroups_fig4B[int(i/2)]])+[boxshadealpha*(i%2)])
        patch.set_edgecolor(groups_colors[vizGroups_fig4B[int(i/2)]])
    for i, patch in enumerate(bplot['whiskers'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/4)]])
    for i, patch in enumerate(bplot['medians'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/2)]])
        
    # ax4B_ECs[c].text(0, 0.8, "pre", verticalalignment='center', horizontalalignment='center', fontsize=5, zorder=200, bbox=dict(pad=1.0, fc='#ddd', ec='#aaa', lw=0.6))
            
    # fdist_orig = df_fig4B_a[(df_fig4B_a['condition'] == condition) & (df_fig4B_a['group'] == group) & (df_fig4B_a['orig_final'] == 'orig')]['fitness'].values
    # fdist_final = df_fig4B_a[(df_fig4B_a['condition'] == condition) & (df_fig4B_a['group'] == group) & (df_fig4B_a['orig_final'] == 'final')]['fitness'].values
    # sig_var  = 0.05 > scipy.stats.levene(fdist_orig, fdist_final, center='mean').pvalue
    # sig_mean = 0.05 > scipy.stats.ttest_ind(fdist_orig, fdist_final, equal_var=sig_var).pvalue
    # print(ECnum, sig_var, sig_mean)
    
    ax4B_ECs[c].set_xticks([0, 1])
    ax4B_ECs[c].set_xticklabels([0, 1], fontsize=5)
    ax4B_ECs[c].set_xlabel('')
    ax4B_ECs[c].set_yticks(box_positions if c==0 else []) # if c!=0 else [1, 2, 3])
    ax4B_ECs[c].set_yticklabels(['pre', 'post', 'pre', 'post', 'pre', 'post'][::-1] if c==0 else [], fontsize=5) # if c!=0 else ['Dip.', 'GPB2', 'PDE2'], fontsize=4)    
    ax4B_ECs[c].set_ylabel('')    
    ax4B_ECs[c].tick_params(width=0.5, length=1.33, pad=2)
    
    # ax4B_ECs[c].get_legend().remove()
    
    # ax4B_ECs[c].spines[['right', 'top', 'left']].set_visible(False)
    
    ax4B_ECs[c].set_facecolor('#fff')
    
    for spine in ax4B_ECs[c].spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    
    # break
    
ax4B_ECs[3].text(-0.25, -0.875, "fitness estimates", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

for c, ECnum in enumerate(ECs_fig4):
    viz_assays_EC = [col for col in viz_fitnesses_final['all'].columns if f"EC{ECnum}" in col]
    print(viz_assays_EC)
    
    for g, group in enumerate(['Diploid', 'GPB2', 'PDE2']):
        
        viz_fEpsilons_g_allReps = np.concatenate([viz_fEpsilons[group].loc[:, assay].values for assay in viz_assays_EC])
        viz_fDeltas_g_allReps   = np.concatenate([viz_fDeltas[group].loc[:, assay].values for assay in viz_assays_EC])
    
        sns.regplot(ax=ax4C_ECs[c], x=viz_fEpsilons_g_allReps, y=-1*viz_fDeltas_g_allReps, ci=0, 
                    color=groups_colors[group], marker=('o' if 'Diploid' in group else 'o'), label=f"{group}", 
                    scatter_kws={'alpha': 0.33 if group=='Diploid' else 1.0, 's': 1, 'lw': 0, 'zorder': g}, line_kws={'lw': 0.66, 'zorder': g})
        
    # ax4C_ECs[c].set_xticks([])
    ax4C_ECs[c].set_xlabel('')
    # ax4C_ECs[c].set_yticks([])
    ax4C_ECs[c].set_ylabel('')    
    # ax4C_ECs[c].get_legend().remove()
    
    ax4C_ECs[c].set_xlim((-1.05, 1.05))
    ax4C_ECs[c].set_xticks([-1, 0, 1])
    ax4C_ECs[c].set_xticklabels([-1, 0, 1], fontsize=5)
    ax4C_ECs[c].set_ylim((-1.05, 1.05))
    ax4C_ECs[c].set_yticks([-1, 0, 1] if c==0 else [])
    ax4C_ECs[c].set_yticklabels([-1, 0, 1] if c==0 else [], fontsize=5)
    
    ax4C_ECs[c].tick_params(width=0.5, length=1.33, pad=2)
    
    for spine in ax4C_ECs[c].spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    
    # ax4C_ECs[c].spines['left'].set_bounds(-1, 1)
    # ax4C_ECs[c].spines['bottom'].set_bounds(-1, 1)
    # if(c == 0):
    #     ax4C_ECs[c].spines[['right', 'top']].set_visible(False)
    # else:
    #     ax4C_ECs[c].spines[['right', 'top', 'left']].set_visible(False)
    
            
    if(c == len(ECs_fig4)-1):
        ax4C_ECs[c].text(1.0, 1.39, "$~~~~~$Diploid$~~~~~~$GPB2$~~~~~~$PDE2", verticalalignment='center', horizontalalignment='right', fontsize=4.5, zorder=100, bbox=dict(pad=1.25, fc='w', ec='#ddd', lw=0.5), clip_on=False)
        ax4C_ECs[c].plot([0.10, 0.30], [1.39, 1.39], c=groups_colors['PDE2'], lw=1, zorder=200, clip_on=False)
        ax4C_ECs[c].plot([-0.95, -0.75], [1.39, 1.39], c=groups_colors['GPB2'], lw=1, zorder=200, clip_on=False)
        ax4C_ECs[c].plot([-2.09, -1.89], [1.39, 1.39], c=groups_colors['Diploid'], lw=1, zorder=200, clip_on=False)
            
    # break    
    

ax4C_ECs[0].set_ylabel("correction ($\Delta f$)", labelpad=0, fontsize=6)        

ax4C_ECs[3].text(-1.1, -1.65, "initial fitness misestimation ($\delta f$)", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)
    
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

for g, group in enumerate(['Diploid', 'GPB2', 'PDE2']):
    
    viz_fEpsilons_g = viz_fEpsilons[group]

    U_g, s_g, VT_g = np.linalg.svd(viz_fEpsilons_g.dropna(axis=1).values[:, :])
    
    u_svd_g    = (s_g[0])*U_g[:, 0]
    u_method_g = viz_biassusc_final[viz_fEpsilons_g.index.tolist()]
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress((u_svd_g), (u_method_g))
    # print("u_svd X u_method_g:", group, "m =", slope, "intc =", intercept, "r^2 =", r_value**2, "p =", p_value, "stderr =", std_err)
    
    if('Diploid' in group):
        ax4D.plot((-1 if slope<0 else 1)*(u_svd_g), (u_method_g), marker='+', ms=2, mew=0.5, ls='none', alpha=0.33, c=groups_colors[group], label=f"{group}", zorder=-100)
    elif('IRA' not in group):  
        sns.regplot(ax=ax4D, x=(-1 if slope<0 else 1)*(u_svd_g), y=(u_method_g), ci=0, label=f"{group} ($r^2 = ${r_value**2:.2f})",
                   color=groups_colors[group], marker=('+' if 'Diploid' in group else 'o'), 
                    scatter_kws={'alpha': 0.5, 's': 4, 'lw': 0, 'zorder': -g}, line_kws={'lw': 1.0, 'zorder': -g})
    
    gcratio = debiaser.variantsInfo.loc[viz_fEpsilons_g.index.tolist(), 'barcode_GCratio'].values
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(1-(gcratio), (u_method_g))
    # print("gcratio X u_method_g:", group, "m =", slope, "intc =", intercept, "r^2 =", r_value**2, "p =", p_value, "stderr =", std_err)

    if('Diploid' in group):
        ax4E.plot((1 if group != 'PDE2' else -1)*(u_svd_g), 1-(gcratio), marker='+', ms=2, mew=0.5, ls='none', alpha=0.33, c=groups_colors[group], label=f"{group}", zorder=-100)
    elif('IRA' not in group):  
        sns.regplot(ax=ax4E, y=1-(gcratio), x=(-1 if group != 'PDE2' else 1)*(u_svd_g), ci=0, label=f"{group} ($r^2 = ${r_value**2:.2f})", 
                   color=groups_colors[group], marker=('+' if 'Diploid' in group else 'o'), 
                    scatter_kws={'alpha': 0.5, 's': 4, 'lw': 0, 'zorder': -g}, line_kws={'lw': 1.0, 'zorder': -g})

ax4E.set_xlabel("$\hat{u}_i^{~(SVD)}$", labelpad=0, fontsize=6) # 
ax4E.set_xticks([-1, 0, 1, 2])
ax4E.set_xticklabels([-1, 0, 1, 2], fontsize=5)
ax4E.set_ylabel("1 - GC-ratio", labelpad=-2, fontsize=6)
ax4E.set_ylim((0, 1))
ax4E.set_yticks([0, 0.5, 1], fontsize=5)
ax4E.set_yticklabels([0, '', 1], fontsize=5)
legend4BC1 = ax4E.legend(scatterpoints=1, scatteryoffsets=[0], frameon=False, loc='lower right',  handletextpad=-0.25, labelspacing=0.2, fontsize=5)
for labeltext in legend4BC1.get_texts(): 
    if(labeltext.get_text() != 'Diploid'):
        labeltext.set_va('center_baseline')

ax4E.tick_params(width=0.5, length=2, pad=2)
for spine in ax4E.spines.values():
    spine.set_edgecolor('#000')
    spine.set_linewidth(0.5)
        
ax4D.set_xlabel("$\hat{u}_i^{~(SVD)}$", labelpad=0, fontsize=6) #^{~(SVD)}
ax4D.set_ylabel("$\hat{u}_{i}^{~(method)}$", labelpad=-2, fontsize=6) # ^{~(method)}
ax4D.set_xlim((-1.5, 2.5))
ax4D.set_ylim((-1.5, 2.5))
ax4D.set_xticks([-1, 0, 1, 2])
ax4D.set_yticks([-1, 0, 1, 2])
ax4D.set_xticklabels([-1, 0, 1, 2], fontsize=5)
ax4D.set_yticklabels([-1, 0, 1, 2], fontsize=5)
ax4D.legend(frameon=False, loc='upper left', fontsize=5)
legend4BC0 = ax4D.legend(scatterpoints=1,scatteryoffsets=[0], frameon=False, loc='upper left', bbox_to_anchor=(-0.05, 1.0), handletextpad=-0.25, labelspacing=0.2, fontsize=5)
for labeltext in legend4BC0.get_texts(): 
    if(labeltext.get_text() != 'Diploid'):
        labeltext.set_va('center_baseline')

ax4D.tick_params(width=0.5, length=2, pad=2)
for spine in ax4D.spines.values():
    spine.set_edgecolor('#000')
    spine.set_linewidth(0.5)
    
ax4E.text(-1.5, 2.575, "D", verticalalignment='top', horizontalalignment='left', fontsize=7, weight='bold', rotation=0)
ax4E.text(-1.5, 1.125, "E", verticalalignment='top', horizontalalignment='left', fontsize=7, weight='bold', rotation=0)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    
# ax4A_ECs[0].add_patch(mpatches.Rectangle(xy=(-10, 0), width=1, height=10, fc='#000', ec='w', lw=0.0, clip_on=False, zorder=-999))
        
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

plt.show()


In [None]:
stop

In [None]:
# VIZ_NOTES = "assaysmatchingjupyterpreprocessing"
# VIZ_NOTES = "skipinitialbaselineoptimization"
# VIZ_NOTES = "inferbiassuscforallvariants-vizvariantswithmeancountabove25inallassays-weightsequalcounts-weightsaturationequalsmintrustthresh-5iters-fixcbars"
VIZ_NOTES = "relabelingassays"

In [None]:
date

In [None]:
# fig4.savefig(f"fig4_{date}_{VIZ_RESULTS}_{VIZ_NOTES}.png", dpi=300)

## Supp Figs

In [None]:
assaysets_figSI = ['EC23', 'EC21', 'EC20', 'EC18', 'EC13', 'EC3']
assaysets_figSI

In [None]:
viz_assays = np.concatenate([[col for col in viz_fitnesses_orig['Diploid'].columns if f"EC{ECnum}" in col] for ECnum in ECs_fig4])
viz_assays

In [None]:
viz_samples = np.concatenate([[col for col in viz_resids_orig.columns if f"EC{ECnum}" in col] for ECnum in [23, 21, 20, 18, 13, 3]])
viz_samples
print(len(viz_samples))

In [None]:
figsize = (7.0, 7.0) 
figSI = plt.figure(constrained_layout=True, figsize=figsize, dpi=200) # 250

gridsize = (int(figsize[0]*10), int(figsize[1]*10))
gs = figSI.add_gridspec(gridsize[1], gridsize[0], wspace=0.0, hspace=0.0, width_ratios=[1]*gridsize[1], height_ratios=[1]*gridsize[0])


grid = {'assaysetlabel': {'y0': 0,  'yf': 2},
        'residsorig':    {'y0': 3,  'yf': 22},
        'biasprev':      {'y0': 26, 'yf': 27},
        'biassusc':      {'y0': 28, 'yf': 47},
        'residsfinal':   {'y0': 28, 'yf': 47},
        'fitnessdistns': {'y0': 50, 'yf': 57},
        'fitnessdeltas': {'y0': 61, 'yf': gridsize[1]}}

# ax_wut = figSI.add_subplot(gs[0:1, 0:1])
# ax_wut2 = figSI.add_subplot(gs[1:2, 1:2])

ax_biassusc    = figSI.add_subplot(gs[grid['biassusc']['y0']:grid['biassusc']['yf'], 0:1])
 

    

# ax_cbar_biasorig    = figSI.add_subplot(gs[grid['residsorig']['y0']:13, -1:])
ax_cbar_resids  = figSI.add_subplot(gs[16:grid['residsorig']['yf'], -2:-1])
ax_cbar_bias   = figSI.add_subplot(gs[grid['residsfinal']['y0']:34, -2:-1])
# ax_cbar_residsfinal = figSI.add_subplot(gs[39:grid['residsfinal']['yf'], -1:])


for a, assayset in enumerate(assaysets_figSI):
    print(assayset)
    
    # if(a != 0 and a != 5):
    #     continue
    
    col_w = 10
    col_x = 2 + (a*col_w)+(a*1)
    
    print(f"{col_x}:{col_x+col_w}")
    
    ax_assaysetlabel = figSI.add_subplot(gs[grid['assaysetlabel']['y0']:grid['assaysetlabel']['yf'], col_x:col_x+col_w])
    ax_residsorig    = figSI.add_subplot(gs[grid['residsorig']['y0']:grid['residsorig']['yf'], col_x:col_x+col_w])
    ax_biasprev      = figSI.add_subplot(gs[grid['biasprev']['y0']:grid['biasprev']['yf'], col_x:col_x+col_w])
    ax_residsfinal   = figSI.add_subplot(gs[grid['residsfinal']['y0']:grid['residsfinal']['yf'], col_x:col_x+col_w])
    ax_fitnessdistns = figSI.add_subplot(gs[grid['fitnessdistns']['y0']:grid['fitnessdistns']['yf'], col_x:col_x+col_w])
    ax_fitnessdeltas = figSI.add_subplot(gs[grid['fitnessdeltas']['y0']:grid['fitnessdeltas']['yf']:, col_x:col_x+col_w])
    
    assays_figSI_aset  = [col for col in viz_fitnesses_final['all'].columns if assayset in col]
    samples_figSI_aset = [col for col in viz_resids_final.columns if assayset in col][1:]
    print(assays_figSI_aset)
    print(samples_figSI_aset)
    
    #--------------------
    
    ax_assaysetlabel.text(0.5, 0.5, assay_relabels[f"{assayset}"], c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)    
    ax_assaysetlabel.set_facecolor('#eee')
    ax_assaysetlabel.set_xticks([])
    ax_assaysetlabel.set_yticks([])
    for spine in ax_assaysetlabel.spines.values():
        spine.set(visible=True, lw=0, edgecolor="white")
    
    #--------------------
    
    hm_residsorig = sns.heatmap(ax=ax_residsorig, data=viz_resids_orig.loc[viz_variants_GCsorted, samples_figSI_aset], cmap=residCmapW, center=0, vmin=-1, vmax=1,  #.shift(-1, axis=1)
                                cbar=(a==0), cbar_ax=ax_cbar_resids, cbar_kws={'label':"residuals", 'drawedges':False, 'ticks':[-1, 1], 'aspect': 20})
    
    ax_residsorig.set_xticks([])
    ax_residsorig.set_yticks([])
    
    ax_residsorig.text(2.75 if len(samples_figSI_aset) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax_residsorig.text(len(samples_figSI_aset)/2, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax_residsorig.text(12.25 if len(samples_figSI_aset) == 14 else 9.75, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    
    if(a == 5):
        ax_residsorig.text(15, 0, "observed residuals", c='k', verticalalignment='top', horizontalalignment='left', rotation=90, fontsize=6, weight='bold', zorder=100)    
    
    #--------------------
    
    hm_biasprev = sns.heatmap(ax=ax_biasprev, data=viz_biasprev_final[samples_figSI_aset], cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, square=False, linewidths=0.0, linecolor='w',
                              cbar=(a==0), cbar_ax=ax_cbar_bias, cbar_kws={'label':"bias", 'drawedges':False, 'ticks':[-1, 1]})
    
    ax_biasprev.set_xticks([])
    ax_biasprev.set_yticks([])
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(0, 0), width=15, height=0.4, fc='w', ec='w', lw=0.0, clip_on=False, zorder=99))
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(0*(len(samples_figSI_aset)+1)/3+0.02, 0.405), width=((len(samples_figSI_aset)+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(1*(len(samples_figSI_aset)+1)/3+0.02, 0.405), width=((len(samples_figSI_aset)+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(2*(len(samples_figSI_aset)+1)/3+0.02, 0.405), width=((len(samples_figSI_aset)+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))    
    
    if(a == 3):
        ax_biasprev.text(0, -0.7, "inferred bias prevalence", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, weight='bold', zorder=100)
    
    #--------------------
    
    hm_residsfinal = sns.heatmap(ax=ax_residsfinal, data=viz_resids_final.loc[viz_variants_GCsorted, samples_figSI_aset], cmap=residCmapW, center=0, vmin=-1, vmax=1, cbar=False)
    
    ax_residsfinal.set_xticks([])
    ax_residsfinal.set_yticks([])
    
    ax_residsfinal.text(2.75 if len(samples_figSI_aset) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax_residsfinal.text(len(samples_figSI_aset)/2, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    ax_residsfinal.text(12.25 if len(samples_figSI_aset) == 14 else 9.75, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    
    if(a == 5):
        ax_residsfinal.text(15, len(viz_biassusc_final[viz_variants_GCsorted]), "bias-adjusted residuals", c='k', verticalalignment='bottom', horizontalalignment='left', rotation=90, fontsize=6, weight='bold', zorder=100)
    
    #--------------------
    
    boxplot_fitnesses = []
    
    box_positions = [0.8, 1.2, 1.8, 2.2, 2.8, 3.2]
    
    for g, group in enumerate(vizGroups_fig4B):
        viz_fitnesses_orig_g = np.array([])
        viz_fitnesses_final_g = np.array([])
        for ai, assay in enumerate(assays_figSI_aset):
            viz_fitnesses_orig_g  = np.hstack([viz_fitnesses_orig_g, viz_fitnesses_orig[group].loc[:, assay].values])
            viz_fitnesses_final_g = np.hstack([viz_fitnesses_final_g, viz_fitnesses_final[group].loc[:, assay].values])
            
        pvalue = scipy.stats.levene(viz_fitnesses_orig_g, viz_fitnesses_final_g).pvalue
        if(pvalue < 0.05):
            ax_fitnessdistns.text(np.percentile(viz_fitnesses_final_g, 99)+(0.125), box_positions[::-1][2*g+1]-0.15, "*", verticalalignment='center', horizontalalignment='center', fontsize=6, color=groups_colors[group])
            
        boxplot_fitnesses.append(viz_fitnesses_orig_g)
        boxplot_fitnesses.append(viz_fitnesses_final_g)
    
    bplot = ax_fitnessdistns.boxplot(boxplot_fitnesses[::-1], vert=False, positions=box_positions, widths=0.25, showcaps=False, showfliers=False, patch_artist=True,
                                boxprops=dict(linewidth=0.66, facecolor='r'), whiskerprops=dict(linewidth=0.66), medianprops=dict(linewidth=0.66))
    
    boxshadealpha = 0.5
    for i, patch in enumerate(bplot['boxes'][::-1]):
        patch.set_facecolor(list(groups_colors[vizGroups_fig4B[int(i/2)]])+[boxshadealpha*(i%2)])
        patch.set_edgecolor(groups_colors[vizGroups_fig4B[int(i/2)]])
    for i, patch in enumerate(bplot['whiskers'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/4)]])
    for i, patch in enumerate(bplot['medians'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/2)]])
    
    ax_fitnessdistns.set_xlim((-1.05, 1.05))
    ax_fitnessdistns.set_xticks([-1, 0, 1])
    ax_fitnessdistns.set_xticklabels([-1, 0, 1], fontsize=5)
    ax_fitnessdistns.set_xlabel('')
    ax_fitnessdistns.set_yticks(box_positions if a==0 else []) # if c!=0 else [1, 2, 3])
    ax_fitnessdistns.set_yticklabels(['pre', 'post', 'pre', 'post', 'pre', 'post'][::-1] if a==0 else [], fontsize=5) 
    ax_fitnessdistns.set_ylabel('')    
    ax_fitnessdistns.tick_params(width=0.5, length=1.33, pad=2)
    
    ax_fitnessdistns.set_facecolor('#fff')
    
    for spine in ax_fitnessdistns.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    

    
    if(a == 3):
        ax_fitnessdistns.text(-0.6, -0.835, "fitness estimates", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)

    #--------------------
    
    for g, group in enumerate(['Diploid', 'GPB2', 'PDE2']):
        
        viz_fEpsilons_g_allReps = np.concatenate([viz_fEpsilons[group].loc[:, assay].values for assay in assays_figSI_aset])
        viz_fDeltas_g_allReps   = np.concatenate([viz_fDeltas[group].loc[:, assay].values for assay in assays_figSI_aset])
    
        sns.regplot(ax=ax_fitnessdeltas, x=viz_fEpsilons_g_allReps, y=-1*viz_fDeltas_g_allReps, ci=0, 
                    color=groups_colors[group], marker=('o' if 'Diploid' in group else 'o'), label=f"{group}", 
                    scatter_kws={'alpha': 0.33 if group=='Diploid' else 1.0, 's': 1, 'lw': 0, 'zorder': g}, line_kws={'lw': 0.66, 'zorder': g})
        
    # ax_fitnessdeltas.set_xticks([])
    ax_fitnessdeltas.set_xlabel('')
    # ax_fitnessdeltas.set_yticks([])
    ax_fitnessdeltas.set_ylabel('')    
    # ax_fitnessdeltas.get_legend().remove()
    
    ax_fitnessdeltas.set_xlim((-1.05, 1.05))
    ax_fitnessdeltas.set_xticks([-1, 0, 1])
    ax_fitnessdeltas.set_xticklabels([-1, 0, 1], fontsize=5)
    ax_fitnessdeltas.set_ylim((-1.05, 1.05))
    ax_fitnessdeltas.set_yticks([-1, 0, 1] if c==0 else [])
    ax_fitnessdeltas.set_yticklabels([-1, 0, 1] if c==0 else [], fontsize=5)
    
    ax_fitnessdeltas.tick_params(width=0.5, length=1.33, pad=2)
    
    for spine in ax_fitnessdeltas.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    
    if(a == len(assaysets_figSI)-1):
        ax_fitnessdeltas.text(1.0, 1.39, "$~~~~~$Diploid$~~~~~~$GPB2$~~~~~~$PDE2", verticalalignment='center', horizontalalignment='right', fontsize=4.5, zorder=100, bbox=dict(pad=1.25, fc='w', ec='#ddd', lw=0.5), clip_on=False)
        ax_fitnessdeltas.plot([0.30, 0.50], [1.39, 1.39], c=groups_colors['PDE2'], lw=1, zorder=200, clip_on=False)
        ax_fitnessdeltas.plot([-0.5, -0.3], [1.39, 1.39], c=groups_colors['GPB2'], lw=1, zorder=200, clip_on=False)
        ax_fitnessdeltas.plot([-1.35, -1.15], [1.39, 1.39], c=groups_colors['Diploid'], lw=1, zorder=200, clip_on=False)
            
    if(a == 0):
        ax_fitnessdeltas.set_ylabel("correction ($\Delta f$)", labelpad=0, fontsize=6)        
    if(a == 3):
        ax_fitnessdeltas.text(-1.1, -1.6, "initial fitness misestimation ($\delta f$)", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)
    
    

    

ax_cbar_resids.set_yticks([-1, 1])
ax_cbar_resids.set_yticklabels([-1, 1], fontsize=4)
ax_cbar_resids.set_ylabel("residuals", fontsize=5, labelpad=-3)
ax_cbar_resids.tick_params(width=0.5, length=1.33, pad=1)
ax_cbar_bias.set_yticks([-1, 1])
ax_cbar_bias.set_yticklabels([-1, 1], fontsize=4)
ax_cbar_bias.set_ylabel("bias", fontsize=5, labelpad=-3)
ax_cbar_bias.tick_params(width=0.5, length=1.33, pad=1)
for spine in ax_cbar_resids.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")
for spine in ax_cbar_bias.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")



sns.heatmap(ax=ax_biassusc, data=np.atleast_2d(viz_biassusc_final[viz_variants_GCsorted]).T, cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, cbar=False)
ax_biassusc.set_xticks([])
ax_biassusc.set_yticks([])
ax_biassusc.set_xticklabels([])
ax_biassusc.set_yticklabels([])
ax_biassusc.set_xlabel('')
ax_biassusc.set_ylabel('')
ax_biassusc.text(-1.25, len(viz_biassusc_final[viz_variants_GCsorted])/2, "inferred bias susceptibility", c='k', verticalalignment='center', horizontalalignment='center', rotation=90, fontsize=6, weight='bold', zorder=100)
for spine in ax_biassusc.spines.values():
    spine.set(visible=True, lw=.4, edgecolor="black")
ax_biassusc.axis('tight')

ax_biassusc.text(-1.25, -470, "A", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, -55, "B", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, 405, "C", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, 615, "D", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)

plt.show()

In [None]:
figSI.savefig(f"figSI-residsfitnesses-pt1_{date}.png", dpi=300)

In [None]:
# stop

In [None]:
assaysets_figSI = ['Baffled', '1.4%Gluc', '1.6%Gluc', '1.8%Gluc', '0.5%Raf', '1.5%Suc1%Raf'] #, '0.2MKCl', '0.5MKCl', 'M3']
# assaysets_figSI = list(set([col.split('-')[0] for col in viz_fitnesses_final['all'].columns if 'EC' not in col.split('-')[0]]))[::-1]
assaysets_figSI

In [None]:
viz_assays = np.concatenate([[col for col in viz_fitnesses_orig['Diploid'].columns if assayset in col] for assayset in assaysets_figSI])
viz_assays

In [None]:
viz_samples = np.concatenate([[col for col in viz_resids_orig.columns if assayset in col] for assayset in assaysets_figSI])
viz_samples
print(len(viz_samples))

In [None]:
figsize = (7.0, 7.0) 
figSI = plt.figure(constrained_layout=True, figsize=figsize, dpi=200) # 250

gridsize = (int(figsize[0]*10), int(figsize[1]*10))
gs = figSI.add_gridspec(gridsize[1], gridsize[0], wspace=0.0, hspace=0.0, width_ratios=[1]*gridsize[1], height_ratios=[1]*gridsize[0])


grid = {'assaysetlabel': {'y0': 0,  'yf': 2},
        'residsorig':    {'y0': 3,  'yf': 22},
        'biasprev':      {'y0': 26, 'yf': 27},
        'biassusc':      {'y0': 28, 'yf': 47},
        'residsfinal':   {'y0': 28, 'yf': 47},
        'fitnessdistns': {'y0': 50, 'yf': 57},
        'fitnessdeltas': {'y0': 61, 'yf': gridsize[1]}}

# ax_wut = figSI.add_subplot(gs[0:1, 0:1])
# ax_wut2 = figSI.add_subplot(gs[1:2, 1:2])

ax_biassusc    = figSI.add_subplot(gs[grid['biassusc']['y0']:grid['biassusc']['yf'], 0:1])
 

    

# ax_cbar_biasorig    = figSI.add_subplot(gs[grid['residsorig']['y0']:13, -1:])
ax_cbar_resids  = figSI.add_subplot(gs[16:grid['residsorig']['yf'], -2:-1])
ax_cbar_bias   = figSI.add_subplot(gs[grid['residsfinal']['y0']:34, -2:-1])
# ax_cbar_residsfinal = figSI.add_subplot(gs[39:grid['residsfinal']['yf'], -1:])


for a, assayset in enumerate(assaysets_figSI):
    print(assayset)
    
    # if(a != 0 and a != 5):
    #     continue
    
    assays_figSI_aset  = [col for col in viz_fitnesses_final['all'].columns if assayset in col]
    samples_figSI_aset = [col for col in viz_resids_final.columns if assayset in col][1:]
    print(assays_figSI_aset)
    print(len(samples_figSI_aset))
    
    col_w = 10
    col_x = 2 + (a*col_w)+(a*1)
    
    print(f"{col_x}:{col_x+col_w}")
    
    ax_assaysetlabel = figSI.add_subplot(gs[grid['assaysetlabel']['y0']:grid['assaysetlabel']['yf'], col_x:col_x+col_w])
    ax_residsorig    = figSI.add_subplot(gs[grid['residsorig']['y0']:grid['residsorig']['yf'], col_x:col_x+col_w])
    ax_biasprev      = figSI.add_subplot(gs[grid['biasprev']['y0']:grid['biasprev']['yf'], col_x:col_x+col_w])
    ax_residsfinal   = figSI.add_subplot(gs[grid['residsfinal']['y0']:grid['residsfinal']['yf'], col_x:col_x+col_w])
    ax_fitnessdistns = figSI.add_subplot(gs[grid['fitnessdistns']['y0']:grid['fitnessdistns']['yf'], col_x:col_x+col_w])
    ax_fitnessdeltas = figSI.add_subplot(gs[grid['fitnessdeltas']['y0']:grid['fitnessdeltas']['yf']:, col_x:col_x+col_w])
    
    #--------------------
    
    ax_assaysetlabel.text(0.5, 0.5, assay_relabels[f"{assayset}"], c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)    
    ax_assaysetlabel.set_facecolor('#eee')
    ax_assaysetlabel.set_xticks([])
    ax_assaysetlabel.set_yticks([])
    for spine in ax_assaysetlabel.spines.values():
        spine.set(visible=True, lw=0, edgecolor="white")
    
    #--------------------
    
    df_residsorig = viz_resids_orig.loc[viz_variants_GCsorted, samples_figSI_aset]
    if(len(samples_figSI_aset) < 14):
        df_residsorig = df_residsorig.reindex(columns=np.concatenate([df_residsorig.columns, ['empty']*int(14-len(samples_figSI_aset))]))
    hm_residsorig = sns.heatmap(ax=ax_residsorig, data=df_residsorig, cmap=residCmapW, center=0, vmin=-1, vmax=1,  #.shift(-1, axis=1)
                                cbar=(a==0), cbar_ax=ax_cbar_resids, cbar_kws={'label':"residuals", 'drawedges':False, 'ticks':[-1, 1], 'aspect': 20})
    
    ax_residsorig.set_xticks([])
    ax_residsorig.set_yticks([])
    
    ax_residsorig.text(2.75 if len(samples_figSI_aset) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R2' in assay for assay in assays_figSI_aset)):
        ax_residsorig.text(7, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R3' in assay for assay in assays_figSI_aset)):
        ax_residsorig.text(12.25, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    
    if(a == 5):
        ax_residsorig.text(15, 0, "observed residuals", c='k', verticalalignment='top', horizontalalignment='left', rotation=90, fontsize=6, weight='bold', zorder=100)    
    
    #--------------------
    
    df_biasprev = viz_biasprev_final[samples_figSI_aset]
    if(len(samples_figSI_aset) < 14):
        df_biasprev = df_biasprev.reindex(columns=np.concatenate([df_biasprev.columns, ['empty']*int(14-len(samples_figSI_aset))]))
    hm_biasprev = sns.heatmap(ax=ax_biasprev, data=df_biasprev, cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, square=False, linewidths=0.0, linecolor='w',
                              cbar=(a==0), cbar_ax=ax_cbar_bias, cbar_kws={'label':"bias", 'drawedges':False, 'ticks':[-1, 1]})
    
    ax_biasprev.set_xticks([])
    ax_biasprev.set_yticks([])
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(0, 0), width=15, height=0.4, fc='w', ec='w', lw=0.0, clip_on=False, zorder=99))
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(0*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    if(any('R2' in assay for assay in assays_figSI_aset)):
        ax_biasprev.add_patch(mpatches.Rectangle(xy=(1*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    if(any('R3' in assay for assay in assays_figSI_aset)):
        ax_biasprev.add_patch(mpatches.Rectangle(xy=(2*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))    
    
    if(a == 3):
        ax_biasprev.text(0, -0.7, "inferred bias prevalence", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, weight='bold', zorder=100)
    
    #--------------------
    
    df_residsfinal = viz_resids_final.loc[viz_variants_GCsorted, samples_figSI_aset]
    if(len(samples_figSI_aset) < 14):
        df_residsfinal = df_residsfinal.reindex(columns=np.concatenate([df_residsfinal.columns, ['empty']*int(14-len(samples_figSI_aset))]))
    hm_residsfinal = sns.heatmap(ax=ax_residsfinal, data=df_residsfinal, cmap=residCmapW, center=0, vmin=-1, vmax=1, cbar=False)
    
    ax_residsfinal.set_xticks([])
    ax_residsfinal.set_yticks([])
    
    ax_residsfinal.text(2.75 if len(samples_figSI_aset) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R2' in assay for assay in assays_figSI_aset)):
        ax_residsfinal.text(7, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R3' in assay for assay in assays_figSI_aset)):
        ax_residsfinal.text(12.25, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    
    if(a == 5):
        ax_residsfinal.text(15, len(viz_biassusc_final[viz_variants_GCsorted]), "bias-adjusted residuals", c='k', verticalalignment='bottom', horizontalalignment='left', rotation=90, fontsize=6, weight='bold', zorder=100)
    
    #--------------------
    
    boxplot_fitnesses = []
    
    box_positions = [0.8, 1.2, 1.8, 2.2, 2.8, 3.2]
    
    for g, group in enumerate(vizGroups_fig4B):
        viz_fitnesses_orig_g = np.array([])
        viz_fitnesses_final_g = np.array([])
        for ai, assay in enumerate(assays_figSI_aset):
            viz_fitnesses_orig_g  = np.hstack([viz_fitnesses_orig_g, viz_fitnesses_orig[group].loc[:, assay].values])
            viz_fitnesses_final_g = np.hstack([viz_fitnesses_final_g, viz_fitnesses_final[group].loc[:, assay].values])
            
        pvalue = scipy.stats.levene(viz_fitnesses_orig_g, viz_fitnesses_final_g).pvalue
        if(pvalue < 0.05):
            ax_fitnessdistns.text(np.percentile(viz_fitnesses_final_g, 99)+(0.125), box_positions[::-1][2*g+1]-0.15, "*", verticalalignment='center', horizontalalignment='center', fontsize=6, color=groups_colors[group])
            
        boxplot_fitnesses.append(viz_fitnesses_orig_g)
        boxplot_fitnesses.append(viz_fitnesses_final_g)
    
    bplot = ax_fitnessdistns.boxplot(boxplot_fitnesses[::-1], vert=False, positions=box_positions, widths=0.25, showcaps=False, showfliers=False, patch_artist=True,
                                boxprops=dict(linewidth=0.66, facecolor='r'), whiskerprops=dict(linewidth=0.66), medianprops=dict(linewidth=0.66))
    
    boxshadealpha = 0.5
    for i, patch in enumerate(bplot['boxes'][::-1]):
        patch.set_facecolor(list(groups_colors[vizGroups_fig4B[int(i/2)]])+[boxshadealpha*(i%2)])
        patch.set_edgecolor(groups_colors[vizGroups_fig4B[int(i/2)]])
    for i, patch in enumerate(bplot['whiskers'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/4)]])
    for i, patch in enumerate(bplot['medians'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/2)]])
    
    ax_fitnessdistns.set_xlim((-1.05, 1.05))
    ax_fitnessdistns.set_xticks([-1, 0, 1])
    ax_fitnessdistns.set_xticklabels([-1, 0, 1], fontsize=5)
    ax_fitnessdistns.set_xlabel('')
    ax_fitnessdistns.set_yticks(box_positions if a==0 else []) # if c!=0 else [1, 2, 3])
    ax_fitnessdistns.set_yticklabels(['pre', 'post', 'pre', 'post', 'pre', 'post'][::-1] if a==0 else [], fontsize=5) 
    ax_fitnessdistns.set_ylabel('')    
    ax_fitnessdistns.tick_params(width=0.5, length=1.33, pad=2)
    
    ax_fitnessdistns.set_facecolor('#fff')
    
    for spine in ax_fitnessdistns.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    

    
    if(a == 3):
        ax_fitnessdistns.text(-0.6, -0.835, "fitness estimates", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)

    #--------------------
    
    for g, group in enumerate(['Diploid', 'GPB2', 'PDE2']):
        
        viz_fEpsilons_g_allReps = np.concatenate([viz_fEpsilons[group].loc[:, assay].values for assay in assays_figSI_aset])
        viz_fDeltas_g_allReps   = np.concatenate([viz_fDeltas[group].loc[:, assay].values for assay in assays_figSI_aset])
    
        sns.regplot(ax=ax_fitnessdeltas, x=viz_fEpsilons_g_allReps, y=-1*viz_fDeltas_g_allReps, ci=0, 
                    color=groups_colors[group], marker=('o' if 'Diploid' in group else 'o'), label=f"{group}", 
                    scatter_kws={'alpha': 0.33 if group=='Diploid' else 1.0, 's': 1, 'lw': 0, 'zorder': g}, line_kws={'lw': 0.66, 'zorder': g})
        
    # ax_fitnessdeltas.set_xticks([])
    ax_fitnessdeltas.set_xlabel('')
    # ax_fitnessdeltas.set_yticks([])
    ax_fitnessdeltas.set_ylabel('')    
    # ax_fitnessdeltas.get_legend().remove()
    
    ax_fitnessdeltas.set_xlim((-1.05, 1.05))
    ax_fitnessdeltas.set_xticks([-1, 0, 1])
    ax_fitnessdeltas.set_xticklabels([-1, 0, 1], fontsize=5)
    ax_fitnessdeltas.set_ylim((-1.05, 1.05))
    ax_fitnessdeltas.set_yticks([-1, 0, 1] if c==0 else [])
    ax_fitnessdeltas.set_yticklabels([-1, 0, 1] if c==0 else [], fontsize=5)
    
    ax_fitnessdeltas.tick_params(width=0.5, length=1.33, pad=2)
    
    for spine in ax_fitnessdeltas.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    
    if(a == len(assaysets_figSI)-1):
        ax_fitnessdeltas.text(1.0, 1.39, "$~~~~~$Diploid$~~~~~~$GPB2$~~~~~~$PDE2", verticalalignment='center', horizontalalignment='right', fontsize=4.5, zorder=100, bbox=dict(pad=1.25, fc='w', ec='#ddd', lw=0.5), clip_on=False)
        ax_fitnessdeltas.plot([0.30, 0.50], [1.39, 1.39], c=groups_colors['PDE2'], lw=1, zorder=200, clip_on=False)
        ax_fitnessdeltas.plot([-0.5, -0.3], [1.39, 1.39], c=groups_colors['GPB2'], lw=1, zorder=200, clip_on=False)
        ax_fitnessdeltas.plot([-1.35, -1.15], [1.39, 1.39], c=groups_colors['Diploid'], lw=1, zorder=200, clip_on=False)
            
    if(a == 0):
        ax_fitnessdeltas.set_ylabel("correction ($\Delta f$)", labelpad=0, fontsize=6)        
    if(a == 3):
        ax_fitnessdeltas.text(-1.1, -1.6, "initial fitness misestimation ($\delta f$)", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)
    
    # break

    

ax_cbar_resids.set_yticks([-1, 1])
ax_cbar_resids.set_yticklabels([-1, 1], fontsize=4)
ax_cbar_resids.set_ylabel("residuals", fontsize=5, labelpad=-3)
ax_cbar_resids.tick_params(width=0.5, length=1.33, pad=1)
ax_cbar_bias.set_yticks([-1, 1])
ax_cbar_bias.set_yticklabels([-1, 1], fontsize=4)
ax_cbar_bias.set_ylabel("bias", fontsize=5, labelpad=-3)
ax_cbar_bias.tick_params(width=0.5, length=1.33, pad=1)
for spine in ax_cbar_resids.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")
for spine in ax_cbar_bias.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")



sns.heatmap(ax=ax_biassusc, data=np.atleast_2d(viz_biassusc_final[viz_variants_GCsorted]).T, cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, cbar=False)
ax_biassusc.set_xticks([])
ax_biassusc.set_yticks([])
ax_biassusc.set_xticklabels([])
ax_biassusc.set_yticklabels([])
ax_biassusc.set_xlabel('')
ax_biassusc.set_ylabel('')
ax_biassusc.text(-1.25, len(viz_biassusc_final[viz_variants_GCsorted])/2, "inferred bias susceptibility", c='k', verticalalignment='center', horizontalalignment='center', rotation=90, fontsize=6, weight='bold', zorder=100)
for spine in ax_biassusc.spines.values():
    spine.set(visible=True, lw=.4, edgecolor="black")
ax_biassusc.axis('tight')

ax_biassusc.text(-1.25, -470, "A", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, -55, "B", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, 405, "C", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, 615, "D", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)

plt.show()

In [None]:
figSI.savefig(f"figSI-residsfitnesses-pt2_{date}.png", dpi=300)

In [None]:
# stop

In [None]:
assaysets_figSI = ['0.2MKCl', '0.5MKCl', 'M3'] # ['Baffled', '1.4%Gluc', '1.6%Gluc', '1.8%Gluc', '0.5%Raf', '1.5%Suc1%Raf'] #, 
# assaysets_figSI = list(set([col.split('-')[0] for col in viz_fitnesses_final['all'].columns if 'EC' not in col.split('-')[0]]))[::-1]
assaysets_figSI

In [None]:
viz_assays = np.concatenate([[col for col in viz_fitnesses_orig['Diploid'].columns if assayset in col] for assayset in assaysets_figSI])
viz_assays

In [None]:
viz_samples = np.concatenate([[col for col in viz_resids_orig.columns if assayset in col] for assayset in assaysets_figSI])
viz_samples
print(len(viz_samples))

In [None]:
figsize = (7.0, 7.0) 
figSI = plt.figure(constrained_layout=True, figsize=figsize, dpi=200) # 250

gridsize = (int(figsize[0]*10), int(figsize[1]*10))
gs = figSI.add_gridspec(gridsize[1], gridsize[0], wspace=0.0, hspace=0.0, width_ratios=[1]*gridsize[1], height_ratios=[1]*gridsize[0])


grid = {'assaysetlabel': {'y0': 0,  'yf': 2},
        'residsorig':    {'y0': 3,  'yf': 22},
        'biasprev':      {'y0': 26, 'yf': 27},
        'biassusc':      {'y0': 28, 'yf': 47},
        'residsfinal':   {'y0': 28, 'yf': 47},
        'fitnessdistns': {'y0': 50, 'yf': 57},
        'fitnessdeltas': {'y0': 61, 'yf': gridsize[1]}}

# ax_wut = figSI.add_subplot(gs[0:1, 0:1])
# ax_wut2 = figSI.add_subplot(gs[1:2, 1:2])

ax_biassusc    = figSI.add_subplot(gs[grid['biassusc']['y0']:grid['biassusc']['yf'], 0:1])
 

    

# ax_cbar_biasorig    = figSI.add_subplot(gs[grid['residsorig']['y0']:13, -1:])
ax_cbar_resids  = figSI.add_subplot(gs[16:grid['residsorig']['yf'], 35:36])
ax_cbar_bias   = figSI.add_subplot(gs[grid['residsfinal']['y0']:34, 35:36])
# ax_cbar_residsfinal = figSI.add_subplot(gs[39:grid['residsfinal']['yf'], -1:])


for a, assayset in enumerate(assaysets_figSI):
    print(assayset)
    
    # if(a != 0 and a != 5):
    #     continue
    
    assays_figSI_aset  = [col for col in viz_fitnesses_final['all'].columns if assayset in col]
    samples_figSI_aset = [col for col in viz_resids_final.columns if assayset in col][1:]
    print(assays_figSI_aset)
    print(len(samples_figSI_aset))
    
    col_w = 10
    col_x = 2 + (a*col_w)+(a*1)
    
    print(f"{col_x}:{col_x+col_w}")
    
    ax_assaysetlabel = figSI.add_subplot(gs[grid['assaysetlabel']['y0']:grid['assaysetlabel']['yf'], col_x:col_x+col_w])
    ax_residsorig    = figSI.add_subplot(gs[grid['residsorig']['y0']:grid['residsorig']['yf'], col_x:col_x+col_w])
    ax_biasprev      = figSI.add_subplot(gs[grid['biasprev']['y0']:grid['biasprev']['yf'], col_x:col_x+col_w])
    ax_residsfinal   = figSI.add_subplot(gs[grid['residsfinal']['y0']:grid['residsfinal']['yf'], col_x:col_x+col_w])
    ax_fitnessdistns = figSI.add_subplot(gs[grid['fitnessdistns']['y0']:grid['fitnessdistns']['yf'], col_x:col_x+col_w])
    ax_fitnessdeltas = figSI.add_subplot(gs[grid['fitnessdeltas']['y0']:grid['fitnessdeltas']['yf']:, col_x:col_x+col_w])
    
    #--------------------
    
    ax_assaysetlabel.text(0.5, 0.5, assay_relabels[f"{assayset}"], c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)    
    ax_assaysetlabel.set_facecolor('#eee')
    ax_assaysetlabel.set_xticks([])
    ax_assaysetlabel.set_yticks([])
    for spine in ax_assaysetlabel.spines.values():
        spine.set(visible=True, lw=0, edgecolor="white")
    
    #--------------------
    
    df_residsorig = viz_resids_orig.loc[viz_variants_GCsorted, samples_figSI_aset]
    if(len(samples_figSI_aset) < 14):
        df_residsorig = df_residsorig.reindex(columns=np.concatenate([df_residsorig.columns, ['empty']*int(14-len(samples_figSI_aset))]))
    hm_residsorig = sns.heatmap(ax=ax_residsorig, data=df_residsorig, cmap=residCmapW, center=0, vmin=-1, vmax=1,  #.shift(-1, axis=1)
                                cbar=(a==0), cbar_ax=ax_cbar_resids, cbar_kws={'label':"residuals", 'drawedges':False, 'ticks':[-1, 1], 'aspect': 20})
    
    ax_residsorig.set_xticks([])
    ax_residsorig.set_yticks([])
    
    ax_residsorig.text(2.75 if len(samples_figSI_aset) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R2' in assay for assay in assays_figSI_aset)):
        ax_residsorig.text(7, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R3' in assay for assay in assays_figSI_aset)):
        ax_residsorig.text(12.25, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R4' in assay for assay in assays_figSI_aset)):
        ax_residsorig.text(17., len(viz_variants_GCsorted)+7, f"rep4", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    
    if(a == 2):
        ax_residsorig.text(21, 0, "observed residuals", c='k', verticalalignment='top', horizontalalignment='left', rotation=90, fontsize=6, weight='bold', zorder=100)    
    
    #--------------------
    
    df_biasprev = viz_biasprev_final[samples_figSI_aset]
    if(len(samples_figSI_aset) < 14):
        df_biasprev = df_biasprev.reindex(columns=np.concatenate([df_biasprev.columns, ['empty']*int(14-len(samples_figSI_aset))]))
    hm_biasprev = sns.heatmap(ax=ax_biasprev, data=df_biasprev, cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, square=False, linewidths=0.0, linecolor='w',
                              cbar=(a==0), cbar_ax=ax_cbar_bias, cbar_kws={'label':"bias", 'drawedges':False, 'ticks':[-1, 1]})
    
    ax_biasprev.set_xticks([])
    ax_biasprev.set_yticks([])
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(0, 0), width=len(samples_figSI_aset)+1, height=0.4, fc='w', ec='w', lw=0.0, clip_on=False, zorder=99))
    ax_biasprev.add_patch(mpatches.Rectangle(xy=(0*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    if(any('R2' in assay for assay in assays_figSI_aset)):
        ax_biasprev.add_patch(mpatches.Rectangle(xy=(1*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))
    if(any('R3' in assay for assay in assays_figSI_aset)):
        ax_biasprev.add_patch(mpatches.Rectangle(xy=(2*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))    
    if(any('R4' in assay for assay in assays_figSI_aset)):
        ax_biasprev.add_patch(mpatches.Rectangle(xy=(3*(14+1)/3+0.02, 0.405), width=((14+1)/3 - 1)*0.97, height=1-0.405, fc='#ffffff00', ec='k', lw=0.4, clip_on=False, zorder=99))    
    
    if(a == 1):
        ax_biasprev.text(7., -0.7, "inferred bias prevalence", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, weight='bold', zorder=100)
    
    #--------------------
    
    df_residsfinal = viz_resids_final.loc[viz_variants_GCsorted, samples_figSI_aset]
    if(len(samples_figSI_aset) < 14):
        df_residsfinal = df_residsfinal.reindex(columns=np.concatenate([df_residsfinal.columns, ['empty']*int(14-len(samples_figSI_aset))]))
    hm_residsfinal = sns.heatmap(ax=ax_residsfinal, data=df_residsfinal, cmap=residCmapW, center=0, vmin=-1, vmax=1, cbar=False)
    
    ax_residsfinal.set_xticks([])
    ax_residsfinal.set_yticks([])
    
    ax_residsfinal.text(2.75 if len(samples_figSI_aset) == 14 else 1.75, len(viz_variants_GCsorted)+7, f"rep1", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R2' in assay for assay in assays_figSI_aset)):
        ax_residsfinal.text(7, len(viz_variants_GCsorted)+7, f"rep2", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R3' in assay for assay in assays_figSI_aset)):
        ax_residsfinal.text(12.25, len(viz_variants_GCsorted)+7, f"rep3", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    if(any('R4' in assay for assay in assays_figSI_aset)):
        ax_residsfinal.text(17., len(viz_variants_GCsorted)+7, f"rep4", c='#999', verticalalignment='top', horizontalalignment='center', rotation=0, fontsize=5, zorder=100)
    
    if(a == 2):
        ax_residsfinal.text(21, len(viz_biassusc_final[viz_variants_GCsorted]), "bias-adjusted residuals", c='k', verticalalignment='bottom', horizontalalignment='left', rotation=90, fontsize=6, weight='bold', zorder=100)
    
    #--------------------
    
    boxplot_fitnesses = []
    
    box_positions = [0.8, 1.2, 1.8, 2.2, 2.8, 3.2]
    
    for g, group in enumerate(vizGroups_fig4B):
        viz_fitnesses_orig_g = np.array([])
        viz_fitnesses_final_g = np.array([])
        for ai, assay in enumerate(assays_figSI_aset):
            viz_fitnesses_orig_g  = np.hstack([viz_fitnesses_orig_g, viz_fitnesses_orig[group].loc[:, assay].values])
            viz_fitnesses_final_g = np.hstack([viz_fitnesses_final_g, viz_fitnesses_final[group].loc[:, assay].values])
            
        pvalue = scipy.stats.levene(viz_fitnesses_orig_g, viz_fitnesses_final_g).pvalue
        if(pvalue < 0.05):
            ax_fitnessdistns.text(np.percentile(viz_fitnesses_final_g, 99)+(0.125), box_positions[::-1][2*g+1]-0.15, "*", verticalalignment='center', horizontalalignment='center', fontsize=6, color=groups_colors[group])
            
        boxplot_fitnesses.append(viz_fitnesses_orig_g)
        boxplot_fitnesses.append(viz_fitnesses_final_g)
    
    bplot = ax_fitnessdistns.boxplot(boxplot_fitnesses[::-1], vert=False, positions=box_positions, widths=0.25, showcaps=False, showfliers=False, patch_artist=True,
                                boxprops=dict(linewidth=0.66, facecolor='r'), whiskerprops=dict(linewidth=0.66), medianprops=dict(linewidth=0.66))
    
    boxshadealpha = 0.5
    for i, patch in enumerate(bplot['boxes'][::-1]):
        patch.set_facecolor(list(groups_colors[vizGroups_fig4B[int(i/2)]])+[boxshadealpha*(i%2)])
        patch.set_edgecolor(groups_colors[vizGroups_fig4B[int(i/2)]])
    for i, patch in enumerate(bplot['whiskers'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/4)]])
    for i, patch in enumerate(bplot['medians'][::-1]):
        patch.set_color(groups_colors[vizGroups_fig4B[int(i/2)]])
    
    ax_fitnessdistns.set_xlim((-1.05, 1.05))
    ax_fitnessdistns.set_xticks([-1, 0, 1])
    ax_fitnessdistns.set_xticklabels([-1, 0, 1], fontsize=5)
    ax_fitnessdistns.set_xlabel('')
    ax_fitnessdistns.set_yticks(box_positions if a==0 else []) # if c!=0 else [1, 2, 3])
    ax_fitnessdistns.set_yticklabels(['pre', 'post', 'pre', 'post', 'pre', 'post'][::-1] if a==0 else [], fontsize=5) 
    ax_fitnessdistns.set_ylabel('')    
    ax_fitnessdistns.tick_params(width=0.5, length=1.33, pad=2)
    
    ax_fitnessdistns.set_facecolor('#fff')
    
    for spine in ax_fitnessdistns.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    

    
    if(a == 1):
        ax_fitnessdistns.text(-0., -0.835, "fitness estimates", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)

    #--------------------
    
    for g, group in enumerate(['Diploid', 'GPB2', 'PDE2']):
        
        viz_fEpsilons_g_allReps = np.concatenate([viz_fEpsilons[group].loc[:, assay].values for assay in assays_figSI_aset])
        viz_fDeltas_g_allReps   = np.concatenate([viz_fDeltas[group].loc[:, assay].values for assay in assays_figSI_aset])
    
        sns.regplot(ax=ax_fitnessdeltas, x=viz_fEpsilons_g_allReps, y=-1*viz_fDeltas_g_allReps, ci=0, 
                    color=groups_colors[group], marker=('o' if 'Diploid' in group else 'o'), label=f"{group}", 
                    scatter_kws={'alpha': 0.33 if group=='Diploid' else 1.0, 's': 1, 'lw': 0, 'zorder': g}, line_kws={'lw': 0.66, 'zorder': g})
        
    # ax_fitnessdeltas.set_xticks([])
    ax_fitnessdeltas.set_xlabel('')
    # ax_fitnessdeltas.set_yticks([])
    ax_fitnessdeltas.set_ylabel('')    
    # ax_fitnessdeltas.get_legend().remove()
    
    ax_fitnessdeltas.set_xlim((-1.05, 1.05))
    ax_fitnessdeltas.set_xticks([-1, 0, 1])
    ax_fitnessdeltas.set_xticklabels([-1, 0, 1], fontsize=5)
    ax_fitnessdeltas.set_ylim((-1.05, 1.05))
    ax_fitnessdeltas.set_yticks([-1, 0, 1] if c==0 else [])
    ax_fitnessdeltas.set_yticklabels([-1, 0, 1] if c==0 else [], fontsize=5)
    
    ax_fitnessdeltas.tick_params(width=0.5, length=1.33, pad=2)
    
    for spine in ax_fitnessdeltas.spines.values():
        spine.set_edgecolor('#000')
        spine.set_linewidth(0.5)
    
    if(a == len(assaysets_figSI)-1):
        ax_fitnessdeltas.text(1.0, 1.39, "$~~~~~$Diploid$~~~~~~$GPB2$~~~~~~$PDE2", verticalalignment='center', horizontalalignment='right', fontsize=4.5, zorder=100, bbox=dict(pad=1.25, fc='w', ec='#ddd', lw=0.5), clip_on=False)
        ax_fitnessdeltas.plot([0.30, 0.50], [1.39, 1.39], c=groups_colors['PDE2'], lw=1, zorder=200, clip_on=False)
        ax_fitnessdeltas.plot([-0.5, -0.3], [1.39, 1.39], c=groups_colors['GPB2'], lw=1, zorder=200, clip_on=False)
        ax_fitnessdeltas.plot([-1.35, -1.15], [1.39, 1.39], c=groups_colors['Diploid'], lw=1, zorder=200, clip_on=False)
            
    if(a == 0):
        ax_fitnessdeltas.set_ylabel("correction ($\Delta f$)", labelpad=0, fontsize=6)        
    if(a == 1):
        ax_fitnessdeltas.text(0, -1.6, "initial fitness misestimation ($\delta f$)", c='k', verticalalignment='center', horizontalalignment='center', rotation=0, fontsize=6, zorder=100)
    
    # break

    

ax_cbar_resids.set_yticks([-1, 1])
ax_cbar_resids.set_yticklabels([-1, 1], fontsize=4)
ax_cbar_resids.set_ylabel("residuals", fontsize=5, labelpad=-3)
ax_cbar_resids.tick_params(width=0.5, length=1.33, pad=1)
ax_cbar_bias.set_yticks([-1, 1])
ax_cbar_bias.set_yticklabels([-1, 1], fontsize=4)
ax_cbar_bias.set_ylabel("bias", fontsize=5, labelpad=-3)
ax_cbar_bias.tick_params(width=0.5, length=1.33, pad=1)
for spine in ax_cbar_resids.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")
for spine in ax_cbar_bias.spines.values():
    spine.set(visible=True, lw=.3, edgecolor="black")



sns.heatmap(ax=ax_biassusc, data=np.atleast_2d(viz_biassusc_final[viz_variants_GCsorted]).T, cmap=biasCmapW, center=0, vmin=-1.5, vmax=1.5, cbar=False)
ax_biassusc.set_xticks([])
ax_biassusc.set_yticks([])
ax_biassusc.set_xticklabels([])
ax_biassusc.set_yticklabels([])
ax_biassusc.set_xlabel('')
ax_biassusc.set_ylabel('')
ax_biassusc.text(-1.25, len(viz_biassusc_final[viz_variants_GCsorted])/2, "inferred bias susceptibility", c='k', verticalalignment='center', horizontalalignment='center', rotation=90, fontsize=6, weight='bold', zorder=100)
for spine in ax_biassusc.spines.values():
    spine.set(visible=True, lw=.4, edgecolor="black")
ax_biassusc.axis('tight')

ax_biassusc.text(-1.25, -470, "A", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, -55, "B", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, 405, "C", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)
ax_biassusc.text(-1.25, 615, "D", verticalalignment='top', horizontalalignment='center', fontsize=7, weight='bold', rotation=0)

plt.show()

In [None]:
figSI.savefig(f"figSI-residsfitnesses-pt3_{date}.png", dpi=300)

In [None]:
# stop

In [None]:
viz_assays = [  'EC23-R1', 'EC23-R2', 'EC23-R3', 
                'EC21-R1', 'EC21-R2', 'EC21-R3',
                'EC20-R1', 'EC20-R2', 'EC20-R3', 
                'EC18-R1', 'EC18-R2', 'EC18-R3',
                'EC13-R1', 'EC13-R2', 'EC13-R3', 
                'EC3-R1', 'EC3-R2', 'EC3-R3',
                'Baffled-R1', 'Baffled-R2',
                '1.4%Gluc-R1', '1.4%Gluc-R2',
                '1.6%Gluc-R1', '1.6%Gluc-R2',
                '1.8%Gluc-R1', '1.8%Gluc-R2', 
                '0.5%Raf-R1', '0.5%Raf-R2',
                '1.5%Suc1%Raf-R1', 
                '0.2MKCl-R1',
                '0.5MKCl-R1',
                'M3-R1', 'M3-R2', 'M3-R3', 'M3-R4' ]
viz_assays

In [None]:
figSI, ax = plt.subplots(7, 5, sharex=True, sharey=True, figsize=(5.0, 7.7), dpi=200)

biassusc_allVariants = debiaser.variantsInfo.loc[:, 'bias_susceptibility'].values
    
controlVariants = debiaser.variantsInfo[debiaser.variantsInfo['control_set'] == True].index.values
    
for a, assay in enumerate(viz_assays):
    # print(a, a//5, a%5, assay)
    r = a//5
    c = a%5
    
    assayInfo  = debiaser.samplesInfo.loc[(debiaser.samplesInfo['assay'] == assay) & (~debiaser.samplesInfo['timept'].isin(debiaser.cfg['exclude_timepts']))].sort_values(by='timept')
    samples_a  = assayInfo['sample'].unique()
    timepts_a  = assayInfo['timept'].values
    biasprev_a = assayInfo['bias_prevalence'].values

    trustworthyVariants_a = np.where(debiaser.trustworthy[assay]==True)[0]
    trustworthyControlVariants_a = list(set(controlVariants).intersection(trustworthyVariants_a))

    biassusc_a = biassusc_allVariants[trustworthyControlVariants_a]

    fitnesses_controlSet_a  = debiaser.log['stage1-iter5']['fitnesses'].loc[trustworthyControlVariants_a, assay].values

    #----

    delta_f_a = fitnesses_controlSet_a - fitnesses_controlSet_a.mean() # fitness misestimates among control set

    lambda_a, v_intcpt_a = debiaser.linear_regression(x=biassusc_a, y=delta_f_a)     

    #---------------------------

    # slope_udf, intercept_udf = debiaser.linear_regression(x=biassusc_a, y=delta_f_a)
    # print("slope_udf:", slope_udf, "intercept_udf:", intercept_udf)
    ax[a//5, a%5].scatter(biassusc_a, delta_f_a, s=4, alpha=0.5, color=groups_colors['Diploid'], linewidth=0)
    lambda_color = biasCmapG(biasCmapG_norm(-lambda_a))
    ax[a//5, a%5].plot(np.sort(biassusc_a), lambda_a*np.sort(biassusc_a) + v_intcpt_a, color=lambda_color, label="$\hat{\lambda} = $"+f"{round(lambda_a, 2)}")

    leg = ax[a//5, a%5].legend(handlelength=0, handletextpad=0, fancybox=True, prop={'size': 5}, loc='lower left')
    for item in leg.legendHandles:
        item.set_visible(False)
    leg.get_frame().set_facecolor(lambda_color)
    leg.get_frame().set_edgecolor(None)

    ax[r, c].set_xticks([-1, 0, 1])
    ax[r, c].set_xticklabels([-1, 0, 1], fontsize=5)
    ax[r, c].set_yticks([-0.5, 0, 0.5])
    ax[r, c].set_yticklabels(["−0.5", "0.0", "0.5"], fontsize=5)
    
    if(r == 3 and c == 0):
        ax[r, c].set_ylabel("fitness misestimates ($\delta\!f$)", fontsize=6)
    if(r == 6 and c == 2):
        ax[r, c].set_xlabel("inferred bias susceptibility ($\hat{u}$)", fontsize=6)
        
    ax[r, c].set_title(assay_relabels[assay.split('-')[0]]+" rep"+assay[-1], fontsize=6)
    
    
figSI.tight_layout()

plt.show()
    
    

In [None]:
figSI.savefig(f"figSI-stage2lambdas_{date}.png", dpi=300)

In [None]:
# stop

In [None]:
figSI, ax = plt.subplots(1, 3, sharex=True, sharey=False, figsize=(7.0, 2.25), dpi=200)

biassusc_allVariants = debiaser.variantsInfo.loc[:, 'bias_susceptibility'].values
    
controlVariants = debiaser.variantsInfo[debiaser.variantsInfo['control_set'] == True].index.values
    
for a, assay in enumerate(['EC21-R2']):
    
    assayInfo  = debiaser.samplesInfo.loc[(debiaser.samplesInfo['assay'] == assay) & (~debiaser.samplesInfo['timept'].isin(debiaser.cfg['exclude_timepts']))].sort_values(by='timept')
    samples_a  = assayInfo['sample'].unique()
    timepts_a  = assayInfo['timept'].values
    biasprev_a = assayInfo['bias_prevalence'].values
    baseline_a = assayInfo['normalization_factor'].values
    
    trustworthyVariants_a = np.where(debiaser.trustworthy[assay]==True)[0]
    trustworthyControlVariants_a = list(set(controlVariants).intersection(trustworthyVariants_a))
    np.random.seed(285) # 285 568 42 : 7 > 3/5/6 > 2 > 1
    viz_variants = np.random.choice(trustworthyControlVariants_a, size=8)
    print(viz_variants)
    
    biassusc_a = biassusc_allVariants[viz_variants]
    
    #--------------------
    
    counts_raw = debiaser.counts.loc[viz_variants, samples_a].values
    
    counts_obs = counts_raw/baseline_a
    
    counts_adj = (counts_raw/baseline_a)*np.exp(biassusc_a[:, np.newaxis] * biasprev_a[np.newaxis, :])

    #--------------------
    
    palette = sns.color_palette('hls', len(viz_variants))
    
    for i, counts_raw_i in enumerate(counts_raw):
        
        ax[0].set_title("Raw")
        
        ax[0].plot(timepts_a, np.log(counts_raw_i), color=palette[i], ls='--', lw=0.75, alpha=0.4)
        
        ax[0].scatter(timepts_a, np.log(counts_raw_i), color=palette[i], s=10, edgecolor='k', linewidth=0.5)
        
        ax[0].set_ylabel('Raw log-counts: $\\log(\\mathcal{C}_{i,t})$', labelpad=0)
        ax[0].set_xlabel('samples ($t$)')        
    
    for i, counts_obs_i in enumerate(counts_obs):
        
        ax[1].set_title("\"Observed\"")
        
        ax[1].plot(timepts_a, np.log(counts_obs_i), color=palette[i], ls='--', lw=0.75, alpha=0.4)
        
        ax[1].scatter(timepts_a, np.log(counts_obs_i), color=palette[i], s=10, edgecolor='k', linewidth=0.5)
        
        slope_i, intcpt_i = debiaser.linear_regression(x=timepts_a, y=np.log(counts_obs_i))
        ax[1].plot(timepts_a, slope_i*timepts_a + intcpt_i, color=palette[i], ls='-', lw=1.1)
        
        ax[1].set_ylabel('Observed log-counts: $\\log({C}_{i,t})$', labelpad=0)
        ax[1].set_xlabel('samples ($t$)')        
        
    for i, counts_adj_i in enumerate(counts_adj):
        
        ax[2].set_title("Bias-adjusted")
        
        ax[2].plot(timepts_a, np.log(counts_adj_i), color=palette[i], ls='--', lw=0.75, alpha=0.4)
        
        ax[2].scatter(timepts_a, np.log(counts_adj_i), color=palette[i], s=10, edgecolor='k', linewidth=0.5)
        
        slope_i, intcpt_i = debiaser.linear_regression(x=timepts_a, y=np.log(counts_adj_i))
        ax[2].plot(timepts_a, slope_i*timepts_a + intcpt_i, color=palette[i], ls='-', lw=1.1)
        
        ax[2].set_ylabel('Bias-adjusted log-counts: $\\log({A}_{i,t})$', labelpad=0)
        ax[2].set_xlabel('samples ($t$)')
    
    
figSI.tight_layout()

plt.show()
    
    

In [None]:
figSI.savefig(f"figSI-countsrawobsadj_{date}.png", dpi=300)

In [None]:
# stop

In [None]:

biassusc_allVariants = debiaser.variantsInfo.loc[:, 'bias_susceptibility'].values
    
controlVariants = debiaser.variantsInfo[debiaser.variantsInfo['control_set'] == True].index.values
    
seeds         = []
slopevars_obs = []
slopevars_adj = []
residsmeans_obs = []
residsmeans_adj = []

for a, assay in enumerate(['EC21-R2']):
    
    for seed in range(1000):
        
        seeds.append(seed)
        
        assayInfo  = debiaser.samplesInfo.loc[(debiaser.samplesInfo['assay'] == assay) & (~debiaser.samplesInfo['timept'].isin(debiaser.cfg['exclude_timepts']))].sort_values(by='timept')
        samples_a  = assayInfo['sample'].unique()
        timepts_a  = assayInfo['timept'].values
        biasprev_a = assayInfo['bias_prevalence'].values
        baseline_a = assayInfo['normalization_factor'].values

        trustworthyVariants_a = np.where(debiaser.trustworthy[assay]==True)[0]
        trustworthyControlVariants_a = list(set(controlVariants).intersection(trustworthyVariants_a))
        np.random.seed(seed) # 7 > 3/5/6 > 2 > 1
        viz_variants = np.random.choice(trustworthyControlVariants_a, size=8)
        
        # print(seed, viz_variants)

        biassusc_a = biassusc_allVariants[viz_variants]

        #--------------------

        counts_raw = debiaser.counts.loc[viz_variants, samples_a].values

        counts_obs = counts_raw/baseline_a

        counts_adj = (counts_raw/baseline_a)*np.exp(biassusc_a[:, np.newaxis] * biasprev_a[np.newaxis, :])

        #--------------------
        
        slopes_obs = []
        slopes_adj = []
        
        resids_obs = []
        resids_adj = []

        for i, counts_obs_i in enumerate(counts_obs):

            slope_i, intcpt_i = debiaser.linear_regression(x=timepts_a, y=np.log(counts_obs_i))
            slopes_obs.append(slope_i)
            
            resids = np.log(counts_obs_i) - (slope_i*timepts_a + intcpt_i)
            resids_obs.append(resids)

        for i, counts_adj_i in enumerate(counts_adj):

            slope_i, intcpt_i = debiaser.linear_regression(x=timepts_a, y=np.log(counts_adj_i))
            slopes_adj.append(slope_i)
            
            resids = np.log(counts_adj_i) - (slope_i*timepts_a + intcpt_i)
            resids_adj.append(resids)
            
        slopevars_obs.append(np.var(slopes_obs))
        slopevars_adj.append(np.var(slopes_adj))
        
        residsmeans_obs.append(np.mean(np.abs(np.array(resids_obs).ravel())))
        residsmeans_adj.append(np.mean(np.abs(np.array(resids_adj).ravel())))


    
    

In [None]:
slopevars_obs = np.array(slopevars_obs)
slopevars_adj = np.array(slopevars_adj)

In [None]:
# slopevars_obs - slopevars_adj

In [None]:
np.argmax(slopevars_obs - slopevars_adj)

In [None]:
(slopevars_obs - slopevars_adj)[np.argmax(slopevars_obs - slopevars_adj)]

In [None]:
np.argsort(slopevars_obs - slopevars_adj)[::-1]

In [None]:
residsmeans_obs = np.array(residsmeans_obs)
residsmeans_adj = np.array(residsmeans_adj)

In [None]:
# residsmeans_obs - residsmeans_adj

In [None]:
np.argmax(residsmeans_obs - residsmeans_adj)

In [None]:
np.argsort(residsmeans_obs - residsmeans_adj)[::-1]

In [None]:
debiaser.variantsInfo

In [None]:
debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull()]

In [None]:
debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull(), [col for col in debiaser.trustworthy.columns if 'EC' in col]]

In [None]:
debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull(), [col for col in debiaser.trustworthy.columns if 'EC' in col]].sum()

In [None]:
np.where(debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull(), :].sum(axis=1).values == 1)

In [None]:
(debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull(), :].sum(axis=1) > 0).values

In [None]:
debiaser.residuals

In [None]:
debiaser.fitnesses

In [None]:
debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull(), [col for col in debiaser.trustworthy.columns if 'EC' in col]].sum()

In [None]:
debiaser.variantsInfo[debiaser.variantsInfo['neutral_group'] == 'PDE2']

In [None]:
variants_trustworthyAtLeast1Assay = 

In [None]:
debiaser.trustworthy.loc[:, [col for col in debiaser.trustworthy.columns if 'EC' in col]].sum(axis=1).values

In [None]:
np.where(debiaser.trustworthy.loc[debiaser.variantsInfo['neutral_group'].notnull(), :].sum(axis=1).values == 1)

In [None]:
matlab_resids_orig

In [None]:
matlab_resids_orig

In [None]:
matlabNonNANvariants = matlab_resids_final.loc[:, [s for s in viz_samples if 'TX' not in s and 'T0' not in s]].dropna().index.values
matlabNonNANvariants

In [None]:
(~matlab_resids_final.loc[:, [s for s in viz_samples if 'TX' not in s and 'T0' not in s]].isnull()).sum()

In [None]:
np.where(debiaser.trustworthy.loc[:, [a for a in debiaser.trustworthy.columns if 'bias' not in a]].all(axis=1))#.index.values

In [None]:
debiaser.trustworthy.loc[:, [a for a in debiaser.trustworthy.columns if 'bias' not in a]].sum()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(6, 10))
sns.heatmap(ax=ax[0], data=debiaser.log['final']['residuals'].loc[pythonGoodForVizVariants, [col for col in debiaser.residuals.columns if 'EC20' in col]], cmap=residCmapW, center=0, vmin=-1.0, vmax=1.0, cbar=False)
sns.heatmap(ax=ax[1], data=matlab_resids_final.loc[matlabNonNANvariants, [col for col in matlab_resids_final.columns if 'EC20' in col]], cmap=residCmapW, center=0, vmin=-1.0, vmax=1.0, cbar=False)

In [None]:
len(pythonGoodForVizVariants)

In [None]:
len(matlabNonNANvariants)

In [None]:
fig, ax = plt.subplots(1,1, figsize=(5, 5))
sns.heatmap(ax=ax, data=debiaser.log['observed']['residuals'].loc[variants_fig4A[:], [col for col in debiaser.residuals.columns if 'EC20' in col]], cmap=residCmapW, center=0, vmin=-1.0, vmax=1.0, )

In [None]:
fig, ax = plt.subplots(1,2, figsize=(6, 10))
sns.heatmap(ax=ax[0], data=debiaser.log['final']['residuals'].loc[variants_fig4A[:], [col for col in debiaser.residuals.columns if 'EC20' in col]], cmap=residCmapW, center=0, vmin=-1.0, vmax=1.0, cbar=False)
sns.heatmap(ax=ax[1], data=matlab_counts_final.loc[variants_fig4A[:], [col for col in matlab_resids_final.columns if 'EC20' in col]], cmap=residCmapW, center=0, vmin=-1.0, vmax=1.0, cbar=False)

In [None]:
fig, ax = plt.subplots(1,1, figsize=(5, 5))
sns.heatmap(ax=ax, data=debiaser.log['final']['residuals'].loc[variants_fig4A[:], [col for col in debiaser.residuals.columns if 'EC20' in col]]-debiaser.log['observed']['residuals'].loc[variants_fig4A[:], [col for col in debiaser.residuals.columns if 'EC20' in col]], cmap=residCmapW, center=0, vmin=-1.0, vmax=1.0, )

In [None]:
debiaser.residuals.loc[variants_fig4A[:20], [col for col in debiaser.residuals.columns if 'EC3' in col]]

In [None]:
debiaser.log['final']['residuals'].loc[variants_fig4A[:20], [col for col in debiaser.residuals.columns if 'EC3' in col]] - debiaser.log['observed']['residuals'].loc[variants_fig4A[:20], [col for col in debiaser.residuals.columns if 'EC3' in col]]

In [None]:
debiaser.counts

In [None]:
goodForViz = pd.DataFrame(index=debiaser.trustworthy.index, columns=debiaser.trustworthy.columns)
goodForViz

In [None]:
for assay in debiaser.samplesInfo['assay'].unique():
    print(assay)
    assayInfo = debiaser.samplesInfo.loc[(debiaser.samplesInfo['assay'] == assay) & (~debiaser.samplesInfo['timept'].isin(debiaser.cfg['exclude_timepts']))].sort_values(by='timept')
    samples_a = assayInfo['sample'].unique()
    print(samples_a)
    mean_counts_a = debiaser.counts[samples_a].mean(axis=1).values
    print(mean_counts_a)
    print(mean_counts_a > 25)
    goodForViz.loc[:, assay] = (mean_counts_a > 25)

In [None]:
goodForViz

In [None]:
len(np.where(goodForViz.all(axis=1))[0])

In [None]:
pythonGoodForVizVariants = np.where(goodForViz.all(axis=1))[0]
pythonGoodForVizVariants

In [None]:
for c, ECnum in enumerate(ECs_fig4):
    viz_assays_EC = [col for col in viz_fitnesses_final['all'].columns if f"EC{ECnum}" in col]
    print(viz_assays_EC)
    
    boxplot_fitnesses = []
    
    for g, group in enumerate(vizGroups_fig4B):
        print(group)
        viz_fitnesses_orig_g_a = np.array([])
        viz_fitnesses_final_g_a = np.array([])
        for a, assay in enumerate(viz_assays_EC):
            viz_fitnesses_orig_g_a  = np.hstack([viz_fitnesses_orig_g_a, viz_fitnesses_orig[group].loc[:, assay].values])
            viz_fitnesses_final_g_a = np.hstack([viz_fitnesses_final_g_a, viz_fitnesses_final[group].loc[:, assay].values])
        # print(len(viz_fitnesses_orig_g_a), np.var(viz_fitnesses_orig_g_a))
        # print(len(viz_fitnesses_final_g_a), np.var(viz_fitnesses_final_g_a))
        pvalue = scipy.stats.levene(viz_fitnesses_orig_g_a, viz_fitnesses_final_g_a).pvalue
        print(pvalue, "***********************" if pvalue<0.05 else "")
        boxplot_fitnesses.append(viz_fitnesses_orig_g_a)
        boxplot_fitnesses.append(viz_fitnesses_final_g_a)

In [None]:
viz_biassusc_final[viz_variants_GCsorted]

In [None]:
viz_biasprev_final[samples_fig4A_EC[1:]]

In [None]:
plt.hist(debiaser.samplesInfo['bias_prevalence'].values, bins=len(debiaser.samplesInfo['bias_prevalence'].values))
plt.show()

In [None]:
plt.hist(debiaser.variantsInfo['bias_susceptibility'].values, bins=len(debiaser.variantsInfo['bias_susceptibility'].values))
plt.show()

In [None]:
debiaser.variantsInfo.loc[viz_variants_GCsorted, 'bias_susceptibility'].max()

In [None]:
plt.hist(debiaser.variantsInfo.loc[viz_variants_GCsorted, 'bias_susceptibility'])
plt.show()