# The Continuous Epistasis Model

The following code trains and evaluates the continuous epistasis model. It uses single- and pairwise CRISPRi perturbations to relate gene expression (relative repression) to growth rate. Single-gene expression-growth rate relationships are modeled with two-parameter sigmoidal functions, and gene-gene coupling is accounted for using two coupling constants. The resulting model is assessed by its RMSD relative to a coupling-insensitive Null model. A subsampled model trained on only 20% of the available pairwise CRISPRi data performs similarly to the full model, supporting a sparse sampling strategy as a way to improve throughput in future experiments.

6/28/22 - Ryan Otto

### Import packages and libraries

In [None]:
import numpy as np
import pandas as pd
import math
import pickle
import matplotlib.pyplot as plt
from scipy import stats
from scipy.optimize import least_squares
import seaborn as sns
import plot_defaults
plot_defaults.change_defaults()

### Define relevant variables

In [None]:
# Date stamp used as the prefix of the intermediate/optimization data files
date = '220815'
# Input, optimization-result, and figure-output directories
output_path, optimization_path, figure_path = 'intermediate_files', 'input_files/optimization_files', 'Figures'
# Genes perturbed by CRISPRi; this order is used for iteration and for the heat-map axes
gene_names = 'dapA dapB gdhA gltB folA thyA glyA purN purL'.split()

### Import data

In [None]:
# Load rescaled growth-rate means/SEMs, qPCR repression means/SEMs, and the raw-data RMSD.
# File names are prefixed with the `date` stamp instead of repeating the literal in every path.
with open(f'{output_path}/{date}_df_growth_pool_filt_rescale.pickle', 'rb') as handle:
    growth_df_rescale = pickle.load(handle)
with open(f'{output_path}/{date}_df_growth_pool_filt_sem_rescale.pickle', 'rb') as handle:
    sem_df_rescale = pickle.load(handle)
with open(f'{output_path}/{date}_repression_mean.pickle', 'rb') as handle:
    qPCR_vals = pickle.load(handle)
with open(f'{output_path}/{date}_repression_sem.pickle', 'rb') as handle:
    qPCR_sem = pickle.load(handle)
with open(f'{output_path}/{date}_raw_data_real_rmsd.pickle', 'rb') as handle:
    raw_data_rmsd = pickle.load(handle)
for guide in ['gdhA_1_42_B_MM14', 'gdhA_3_216_B_MM8']:  # Off target guides, as shown in notebook 2
    # Guide names start with the targeted gene, e.g. 'gdhA_1_42_B_MM14' -> 'gdhA'
    target_gene = guide.split('_')[0]
    qPCR_vals[target_gene].pop(guide)
    qPCR_sem[target_gene].pop(guide)

### Define analysis functions

In [None]:
def hill_fit(exp_growth_df, gr_sem, rep_mean, rep_sem, genes, figure_path):
    """Fit single gene expression-growth curves with a two-parameter sigmoid function, then plot the data and model
    For each gene, the single-knockdown growth rates (the 'negC_rand_42' column of the gene's landscape returned by
    exp_extract) are fit against mean repression with growth_rate via least_squares, and one figure per gene is
    saved as {figure_path}/Fig{key}.pdf. After the genes in `genes`, 'thyA' is fit and plotted a second time with
    one color per sgRNA (figure key '1C').
    Arguments:
    exp_growth_df: Growth rate averages for all CRISPRi perturbations (sgRNA names on rows and columns)
    gr_sem: Growth rate SEM for all CRISPRi perturbations, same layout as exp_growth_df
    rep_mean: qPCR averages for all single-gene expression perturbations ({gene: {sgRNA: repression}})
    rep_sem: qPCR SEM for all single-gene expression perturbations ({gene: {sgRNA: SEM}})
    genes: List of gene names
    figure_path: Path to a folder storing figure output
    Returns:
    hill_elements: Best fit parameters [Ro, n] for each gene's expression-growth rate sigmoid fit, keyed by gene name
    """
    # fig_key is indexed by loop position, so at most 10 iterations are supported; sorted_colors provides one color
    # per sgRNA for the extra thyA plot (at most 10 sgRNAs)
    hill_elements, fig_key, sorted_colors = {}, ['2E', '2F', '2L', '2M', '2I', '2J', '2K', '2G', '2H', '1C'], \
                ['xkcd:dark gray', 'xkcd:cherry red', 'xkcd:sky blue', 'xkcd:forest green', 'xkcd:turquoise',
                 'xkcd:purple', 'xkcd:grass green', 'xkcd:sea blue', 'xkcd:dark yellow', 'xkcd:lilac']
    for i, gene_name in enumerate(genes + ['thyA']):
        # Passing the same gene twice yields that gene's own landscape; its 'negC_rand_42' column holds the
        # growth rates with only this gene's sgRNA (see exp_extract)
        single_growth, _, _ = exp_extract(exp_growth_df, rep_mean, gene_name, gene_name)
        gr_mean_df = single_growth['negC_rand_42']
        gr_sem_list = [gr_sem.loc[sgRNA, 'negC_rand_42'] for sgRNA in gr_mean_df.index]
        KD_mean_list = [rep_mean[gene_name][sgRNA] for sgRNA in single_growth.index]
        KD_sem_list = [rep_sem[gene_name][sgRNA] for sgRNA in single_growth.index]
        gr_mean_list, gr_sem_list, KD_mean_list, KD_sem_list = np.array(gr_mean_df.values), np.array(gr_sem_list), \
                                                               np.array(KD_mean_list), np.array(KD_sem_list)
        mask_nans = ~np.isnan(gr_mean_list)  # Remove nans from fit points
        gr_mean_list, gr_sem_list, KD_mean_list, KD_sem_list = gr_mean_list[mask_nans], gr_sem_list[mask_nans], \
                                                               KD_mean_list[mask_nans], KD_sem_list[mask_nans]
        # Fit [Ro, n] starting from [0.5, 1]; Ro is constrained to be non-negative, n is unbounded
        hill_elements[gene_name] = least_squares(residuals_growth_rate, x0=[0.5, 1],
                                        bounds=[(0, -np.inf), (np.inf, np.inf)], args=(gr_mean_list, KD_mean_list)).x
        # Smooth fit curve from floor(min repression) up to 1.25
        xVals = np.arange(np.floor(min(KD_mean_list)), 1.25, 0.01)
        fitVals = growth_rate(xVals, hill_elements[gene_name][0], hill_elements[gene_name][1])
        fig, ax = plt.subplots(figsize=(5, 5))
        # NOTE(review): index 9 is the extra 'thyA' pass only when len(genes) == 9 (as for gene_names); it is
        # drawn with one color per sgRNA
        if i == 9:
            for j, KD in enumerate(KD_mean_list):
                ax.errorbar(KD, gr_mean_list[j], xerr=KD_sem_list[j], yerr=gr_sem_list[j], fmt='o',
                            color=sorted_colors[j], ms=10, elinewidth=2, mec='xkcd:dark gray')
        else:
            ax.errorbar(KD_mean_list, gr_mean_list, xerr=KD_sem_list, yerr=gr_sem_list, fmt='o',
                        color='xkcd:dark gray', ms=10, elinewidth=2)
        ax.plot(xVals, fitVals, '-b', lw=3)
        ax.tick_params(axis='both', labelsize=18)
        ax.set_xlabel('Repression Efficiency', fontsize=20)
        ax.set_ylabel('Relative GR', fontsize=20)
        ax.set_title(gene_name, loc='center', fontsize=20)
        ax.set_ylim([0, 1.17])
        ax.set_yticks([0.5, 1])
        plt.tight_layout()
        plt.savefig(f'{figure_path}/Fig{fig_key[i]}.pdf')
        plt.show()
    return hill_elements


def exp_extract(exp_growth_df, rep_mean, gene_name_1, gene_name_2):
    """Function to sort and extract relevant growth rates for a given pair of genes.
    Rows and columns of the returned landscape are ordered by ascending mean repression (rep_mean).
    Arguments:
    exp_growth_df: Growth rate averages for all CRISPRi perturbations (sgRNA names on rows and columns)
    rep_mean: qPCR averages for all single-gene expression perturbations ({gene: {sgRNA: repression}})
    gene_name_1: Name of one gene in the pair
    gene_name_2: Name of the other gene in the pair
    Returns:
    pair_data: Pairwise expression-growth rate data for the gene pair
    gene1: Gene titrated along the index of the pairwise data
    gene2: Gene titrated along the columns of the pairwise data
    Raises:
    ValueError: If no sgRNA in exp_growth_df.index starts with either gene name
    """
    gene1 = gene2 = None
    for guide in exp_growth_df.index:
        # The following logic establishes which gene lies on which axis. Every matching guide overwrites the
        # choice, so the last guide in the index that belongs to either gene determines the orientation.
        if guide.startswith(gene_name_1):
            gene1, gene2 = gene_name_1, gene_name_2
        elif guide.startswith(gene_name_2):
            gene1, gene2 = gene_name_2, gene_name_1
    if gene1 is None:
        raise ValueError(f'No sgRNA in exp_growth_df.index starts with {gene_name_1!r} or {gene_name_2!r}')
    gene1_sgRNAs = list(rep_mean[gene1])
    gene2_sgRNAs = list(rep_mean[gene2])
    pair_data = pd.DataFrame(np.full((len(gene1_sgRNAs), len(gene2_sgRNAs)), np.nan), gene1_sgRNAs, gene2_sgRNAs)
    for sgRNA2 in gene2_sgRNAs:
        for sgRNA1 in gene1_sgRNAs:
            if sgRNA1 == 'negC_rand_42':  # Single knockdowns are stored with the guide on rows, negC on columns
                pair_data.loc['negC_rand_42', sgRNA2] = exp_growth_df.loc[sgRNA2, 'negC_rand_42']
            else:
                pair_data.loc[sgRNA1, sgRNA2] = exp_growth_df.loc[sgRNA1, sgRNA2]
    # Sort rows and columns by repression intensity. Series.sort_values uses the same default (quicksort) ordering
    # as DataFrame.sort_values, without inserting temporary sentinel rows/columns into the data.
    col_order = pd.Series([rep_mean[gene2][sgRNA] for sgRNA in pair_data.columns], index=pair_data.columns,
                          dtype=float).sort_values(ascending=True).index
    row_order = pd.Series([rep_mean[gene1][sgRNA] for sgRNA in pair_data.index], index=pair_data.index,
                          dtype=float).sort_values(ascending=True).index
    return pair_data.loc[row_order, col_order], gene1, gene2


def residuals_growth_rate(p, y, x):
    """Return the residuals of a single-gene sigmoid growth-rate fit.
    Arguments:
    p: Sigmoid parameters as a pair (Ro, n)
    y: Measured growth rates
    x: Repression value for each measurement in y
    Returns:
    err: Measured growth rates minus the sigmoid predictions at x
    """
    half_max_repression, steepness = p
    predicted = growth_rate(x, half_max_repression, steepness)
    return y - predicted


def growth_rate(r, Ro, n):
    """Calculates an expected growth rate using the sigmoid 1 / (1 + exp(n*(r - Ro))).
    Arguments:
    r: Repression level(s) to use when predicting growth rates; a scalar or a 1D sequence/array
    Ro: Repression level at half-maximal growth rate
    n: Steepness of the repression-growth rate function
    Returns:
    g_rate: Predicted growth rate. A list of floats when r holds more (or fewer) than one value;
            a single float for a scalar r or a one-element sequence
    """
    def _sigmoid(repression):
        try:
            return 1 / (1 + math.exp(n*(repression-Ro)))
        except OverflowError:  # exp(z) overflows for z > ~709; the sigmoid's limit there is 0
            return 0.0

    if np.ndim(r) == 0:  # Individual value (float, numpy scalar, or 0-d array)
        return _sigmoid(r)
    values = list(r)
    if len(values) == 1:  # One-element input returns a scalar, as before
        return _sigmoid(values[0])
    return [_sigmoid(repression) for repression in values]

In [None]:
def subsample_gr(exp_growth_df, rep_mean, hill_params, frac_drop=0):
    """Generate three groups of pairwise expression-growth rate data.
    The first group includes all data gathered in the pairwise experiment.
    The second contains only growth rates following two targeting CRISPRi perturbations.
    The final group has only pairwise CRISPRi data, and is subsampled according to frac_drop.
    Arguments:
    exp_growth_df: Growth rate averages for all CRISPRi perturbations
    rep_mean: qPCR averages for all single-gene expression perturbations
    hill_params: Best fit parameters for each gene's expression-growth rate sigmoid fit (only its keys are used)
    frac_drop: A fraction (between 0 and 1) of the pairwise data to remove for subsampling; uses np.random
    Returns:
    full_data: The first group of data, containing all measurements
    pairwise_data: The second group of data, containing only pairwise measurements
    subsampled_data: The third group of data, containing subsampled pairwise measurements
    """
    full_data, pairwise_data, subsampled_data = {}, {}, {}
    genes = list(hill_params)
    for i, gene_1 in enumerate(genes):
        for gene_2 in genes[i+1:]:
            gene_pair = (gene_1, gene_2)
            full_data[gene_pair], _, _ = exp_extract(exp_growth_df, rep_mean, gene_1, gene_2)
            ind_names, col_names = full_data[gene_pair].index, full_data[gene_pair].columns
            # The following subsamples data from the complete landscape
            pairwise_data[gene_pair] = full_data[gene_pair].copy()  # Don't change the original
            pairwise_data[gene_pair].loc['negC_rand_42'] = np.nan  # Remove first-order perturbations
            pairwise_data[gene_pair]['negC_rand_42'] = np.nan  # Remove first-order perturbations
            subsampled_data[gene_pair] = pairwise_data[gene_pair].copy()  # Don't change the original
            negC_row = list(ind_names).index('negC_rand_42')  # Don't subsample from first-order perturbations
            negC_col = list(col_names).index('negC_rand_42')  # Don't subsample from first-order perturbations
            n_pair_rows = len(ind_names) - 1  # Rows excluding the nontargeting row
            tot_ind = list(range(n_pair_rows*(len(col_names)-1)))  # Column-major flat index of pairwise cells
            np.random.shuffle(tot_ind)  # Randomize index order
            for drop in range(int(np.ceil(len(tot_ind)*frac_drop))):  # Drop the first frac_drop percent
                col_ind, row_ind = divmod(tot_ind[drop], n_pair_rows)
                if row_ind >= negC_row:  # Renumber to skip the first-order perturbations
                    row_ind += 1
                if col_ind >= negC_col:  # Renumber to skip the first-order perturbations
                    col_ind += 1
                # Single .loc assignment: chained `df[col][row] = ...` is a silent no-op under Copy-on-Write
                subsampled_data[gene_pair].loc[ind_names[row_ind], col_names[col_ind]] = np.nan
    return full_data, pairwise_data, subsampled_data


def calc_avals(pairwise_data, rep_mean, hill_params, penalty=0):
    """Fit the two coupling constants of every gene pair to its pairwise expression-growth rate data.
    Arguments:
    pairwise_data: Pairwise CRISPRi growth rate data, keyed by (gene_1, gene_2) in hill_params order
    rep_mean: qPCR averages for all single-gene expression perturbations
    hill_params: Best fit parameters for each gene's expression-growth rate sigmoid fit
    penalty: Regularization term to penalize the absolute value of coupling constants
    Returns:
    pair_avals: Coupling constants for each gene pair (each bounded to [-1.2, 10])
    """
    genes = list(hill_params)
    pair_avals = {}
    for idx, first in enumerate(genes):
        for second in genes[idx + 1:]:
            landscape = pairwise_data[(first, second)]
            # Rows are looked up as sgRNAs of the second gene, columns as sgRNAs of the first
            row_repression = [rep_mean[second][guide] for guide in landscape.index]
            col_repression = [rep_mean[first][guide] for guide in landscape.columns]
            fit = least_squares(residuals_growthPairs, x0=[0, 0], bounds=[(-1.2, -1.2), (10, 10)],
                                args=(landscape.values, row_repression, col_repression,
                                      hill_params[second], hill_params[first], penalty))
            pair_avals[(first, second)] = fit.x
    return pair_avals


def residuals_growthPairs(avals, y, rep_1, rep_2, p1, p2, penalty):
    """Return the regularized fit error of pairwise growth-rate predictions as a single residual.
    The objective is the RMSD between measured and predicted pairwise growth rates plus an L1 term,
    penalty * sum(|a|), on the coupling constants. least_squares minimizes a sum of squared residuals,
    so the square root of the objective is returned to make the optimizer minimize the objective itself.
    Per the original author's testing, omitting the square root changes the optimum only numerically (<10**-8).
    Arguments:
    avals: Coupling constants between the genes of interest
    y: True pairwise growth rates (2D array; NaN entries are ignored)
    rep_1: All repression values for gene 1
    rep_2: All repression values for gene 2
    p1: Growth rate parameters Ro and n for gene 1
    p2: Growth rate parameters Ro and n for gene 2
    penalty: Regularization term to penalize the absolute value of coupling constants
    Returns:
    err: Square root of (RMSD + L1 penalty)
    """
    predicted = pairGrowthRates(avals, rep_1, rep_2, p1, p2)
    observed = y.flatten()
    keep = ~np.isnan(observed)  # Skip unmeasured / removed points
    misfit = nanrms(observed[keep] - predicted.flatten()[keep])
    l1_term = sum(penalty*abs(a) for a in avals)
    return np.sqrt(misfit + l1_term)


def pairGrowthRates(avals, rep_1, rep_2, p1, p2):
    """Calculate predicted pairwise growth rates following two coupled gene knockdowns.
    The effective repression of each gene is solved with solveReff, converted to a single-gene growth rate with
    growth_rate, and the two single-gene growth rates are multiplied elementwise.
    Arguments:
    avals: Coupling constants between the genes of interest
    rep_1: All repression values for gene 1 (titrated along the rows of the output)
    rep_2: All repression values for gene 2 (titrated along the columns of the output)
    p1: Growth rate parameters Ro and n for gene 1
    p2: Growth rate parameters Ro and n for gene 2
    Returns:
    gr_pair: 2D array (len(rep_1) x len(rep_2)) of predicted pairwise growth rates, or all zeros if any
             effective repression exceeds 10 in absolute value
    """
    Ro1, n1 = p1
    Ro2, n2 = p2
    r1eff, r2eff, _, _ = solveReff(avals, rep_1, rep_2, Ro1, Ro2)
    # On rare occasions, early rounds of optimization can generate abnormally large effective repression values.
    # This logic manually "resets" the optimization, and testing has shown that this is sufficient.
    if max([abs(x) for x in r1eff.flatten()]) > 10 or max([abs(x) for x in r2eff.flatten()]) > 10:
        return np.zeros(np.shape(r1eff))
    gr1, gr2 = np.zeros(np.shape(r1eff)), np.zeros(np.shape(r2eff))
    for i, r1 in enumerate(r1eff):
        gr1[i, :] = growth_rate(r1, Ro1, n1)
    for i, r2 in enumerate(r2eff):
        gr2[i, :] = growth_rate(r2, Ro2, n2)
    # Single-gene growth effects combine multiplicatively (elementwise product over the grid)
    return gr1 * gr2


def solveReff(avals, rep_1, rep_2, Ro_1, Ro_2):
    """Solve for the effective repression of two CRISPRi perturbations given two coupling constants.
    For every (rep_1[i], rep_2[j]) combination the effective repressions are found by fixed-point iteration of
        r1eff = r1 / (1 + avals[0] * x2 / (1 + x2)),   x2 = r2eff / Ro_2
        r2eff = r2 / (1 + avals[1] * x1 / (1 + x1)),   x1 = r1eff / Ro_1
    starting from the uncoupled values, with both grids updated from the previous iterate. Iteration stops once the
    summed absolute change of both grids is at most 0.01, or after 100 iterations.
    Arguments:
    avals: Coupling constants between the genes of interest
    rep_1: All repression values for gene 1
    rep_2: All repression values for gene 2
    Ro_1: Growth rate parameter Ro for gene 1
    Ro_2: Growth rate parameter Ro for gene 2
    Returns:
    r1eff: 2D array (len(rep_1) x len(rep_2)) of the first perturbation's relative repression after coupling
    r2eff: 2D array (len(rep_1) x len(rep_2)) of the second perturbation's relative repression after coupling
    resids_r1: Per-iteration summed absolute change of r1eff, returned for troubleshooting and optimization
    resids_r2: Per-iteration summed absolute change of r2eff, returned for troubleshooting and optimization
    """
    # Uncoupled repression grids: rows index rep_1, columns index rep_2
    r1_base = np.tile(np.array(rep_1), (len(rep_2), 1)).T
    r2_base = np.tile(np.array(rep_2), (len(rep_1), 1))
    r1eff, r2eff = r1_base, r2_base
    resids, eps, count = np.inf, 0.01, 0  # Initialize sum of residuals, desired final residual, and count iterator
    resids_r1, resids_r2 = [], []
    # Iteratively reduce the residuals using the update formulas (vectorized over the whole grid);
    # if the desired residual is not reached, exit after 100 iterations
    while resids > eps and count < 100:
        x1 = r1eff / Ro_1
        x2 = r2eff / Ro_2
        r1update = r1_base / (1 + avals[0]*(x2/(1 + x2)))
        r2update = r2_base / (1 + avals[1]*(x1/(1 + x1)))
        resids_r1.append(np.sum(abs(r1eff - r1update)))
        resids_r2.append(np.sum(abs(r2eff - r2update)))
        resids = resids_r1[-1] + resids_r2[-1]
        r1eff, r2eff = r1update, r2update
        count += 1
    return r1eff, r2eff, resids_r1, resids_r2


def nanrms(x, axis=None):
    """Self-defined root mean square function that ignores NaNs. Used for convenience and consistency.
    Arguments:
    x: Array-like of residuals (NumPy arrays, lists and tuples are accepted)
    axis: In case of an array of arrays, enter the desired axis to calculate across (None uses every element)
    Returns:
    The square root of the NaN-ignoring mean of x**2 (a scalar, or an array when axis is given)
    """
    values = np.asarray(x)  # Plain lists would otherwise fail at `x**2`
    return np.sqrt(np.nanmean(values**2, axis=axis))


# Calculate multiplicative Bliss prediction
def pred_mult_calculate(exp_growth_df):
    """Calculate predicted pairwise growth rates using a multiplicative Bliss model
    Each prediction is the product of the two single-perturbation growth rates:
    mult_gr.loc[sgRNA1, sgRNA2] = exp_growth_df.loc[sgRNA1, 'negC_rand_42'] * exp_growth_df.loc['negC_rand_42', sgRNA2]
    Arguments:
    exp_growth_df: Pairwise growth rates with a 'negC_rand_42' row and column holding single perturbations
    Returns:
    mult_gr: Predicted growth rates from a multiplicative Bliss model, with exp_growth_df's index and columns
    """
    row_singles = exp_growth_df['negC_rand_42'].to_numpy(dtype=float)  # One value per row sgRNA
    col_singles = exp_growth_df.loc['negC_rand_42'].to_numpy(dtype=float)  # One value per column sgRNA
    return pd.DataFrame(np.outer(row_singles, col_singles), exp_growth_df.index, exp_growth_df.columns)

In [None]:
def calc_pair_gr(exp_growth_df, rep_mean, avals, pairwise_data, hill_params):
    """Predict pairwise-CRISPRi growth rates for every gene pair that has coupling constants.
    Arguments:
    exp_growth_df: Growth rate averages for all CRISPRi perturbations (used to orient each pair via exp_extract)
    rep_mean: qPCR averages for all single-gene expression perturbations
    avals: Coupling constants for each gene pair; its keys select the pairs that are predicted
    pairwise_data: Pairwise CRISPRi growth rate data that defines each prediction grid's rows and columns
    hill_params: Best fit parameters for each gene's expression-growth rate sigmoid fit
    Returns:
    pred_dict: Predicted growth-rate arrays keyed by gene pair
    """
    pred_dict = {}
    for gene_pair, couplings in avals.items():
        # Only the axis assignment from exp_extract is needed here; the landscape itself is discarded
        _, row_gene, col_gene = exp_extract(exp_growth_df, rep_mean, *gene_pair)
        landscape = pairwise_data[gene_pair]
        row_KD = [rep_mean[gene_pair[1]][guide] for guide in landscape.index]
        col_KD = [rep_mean[gene_pair[0]][guide] for guide in landscape.columns]
        pred_dict[gene_pair] = pairGrowthRates(couplings, row_KD, col_KD, hill_params[row_gene],
                                               hill_params[col_gene])
    return pred_dict


def heat_map(df, ax=None, xticks=None, yticks=None, xlabel=None, ylabel=None, vmin=0, vmax=2, cmap='PuOr_r'):
    """Draw df as an image-style heat map on ax (or the current axes) with optional tick labels.
    Arguments:
    df: Data to plot
    ax: Axes object, if available; otherwise the current pyplot axes are used
    xticks, yticks: Tick labels (x labels are placed on top); falsy values hide the ticks
    xlabel, ylabel: Axis labels
    vmin, vmax, cmap: Color scaling limits and colormap passed to imshow
    Returns:
    ax: Resulting axes object
    cb: The AxesImage returned by imshow, suitable for plt.colorbar
    """
    ax = ax if ax else plt.gca()
    cb = ax.imshow(df, cmap=cmap, interpolation='none', vmin=vmin, vmax=vmax)
    ax.set_xlabel(xlabel, fontsize=24)
    ax.set_ylabel(ylabel, fontsize=24)
    if not xticks:
        ax.set_xticks([])
    else:
        ax.set_xticks(np.arange(len(xticks)))
        ax.set_xticklabels(xticks, fontsize=18)
        ax.xaxis.tick_top()
    if not yticks:
        ax.set_yticks([])
    else:
        ax.set_yticks(np.arange(len(yticks)))
        ax.set_yticklabels(yticks, fontsize=18)
    ax.spines[['top', 'right']].set_visible(True)
    # Background color behind the image (visible wherever cells are not drawn)
    ax.set_facecolor("darkgrey")
    return ax, cb


def cross_validate(pairwise_data, subsampled_data, sub_prediction):
    """Score predictions only on the points that were held out of training.
    Arguments:
    pairwise_data: Pairwise CRISPRi growth rate data
    subsampled_data: Subsampled pairwise CRISPRi growth rate data used in training (NaN marks held-out points)
    sub_prediction: Predicted growth rates, as 2D arrays aligned with pairwise_data
    Returns:
    rms_dict_sub: Held-out RMSD for each gene pair
    """
    rms_dict_sub = {}
    for gene_pair, observed in pairwise_data.items():
        training = subsampled_data[gene_pair]
        prediction = sub_prediction[gene_pair]
        held_out_errors = [
            observed.loc[row, col] - prediction[r, c]
            for r, row in enumerate(observed.index)
            for c, col in enumerate(observed.columns)
            if np.isnan(training.loc[row, col])  # If not in training set
        ]
        rms_dict_sub[gene_pair] = nanrms(np.array(held_out_errors))
    return rms_dict_sub

### Fit single expression-growth curves

We fit individual sigmoids to single-gene expression-growth rate data. These sigmoids are described by two parameters, $R_0$ (repression at half-maximal growth rate) and n (steepness).

In [None]:
# Fit one sigmoid (Ro, n) per gene and save the single-gene figures
hill_elements = hill_fit(exp_growth_df=growth_df_rescale, gr_sem=sem_df_rescale, rep_mean=qPCR_vals,
                         rep_sem=qPCR_sem, genes=gene_names, figure_path=figure_path)

### Fit coupling constants and calculate predicted growth rates

We fit coupling constants to pairwise expression-growth rate data, then use these coupling constants and single-gene sigmoid parameters to generate predicted growth rates for each CRISPRi perturbation. We also predict growth rates using subsampled pairwise growth rate data as well as a Null model without coupling and a multiplicative model that considers only sgRNA identity and cannot make predictions based on expression changes.

In [None]:
# Regularization penalty, and fraction of pairwise points dropped (frac_drop) for the subsampled model
reg, sub = 10**-1.25, 0.8
fit_gene_all, fit_gene_pairs, fit_gene_pairs_sub = subsample_gr(growth_df_rescale, qPCR_vals, hill_elements, sub)
pair_avals = calc_avals(fit_gene_pairs, qPCR_vals, hill_elements, reg)
pair_avals_sub = calc_avals(fit_gene_pairs_sub, qPCR_vals, hill_elements, reg)
# Baselines: Bliss (multiplicative) predictions from measured single knockdowns, and a Null model with zero coupling
mult_pred = {gene_pair: pred_mult_calculate(fit_gene_all[gene_pair]) for gene_pair in pair_avals_sub}
pair_avals_null = {gene_pair: np.array([0, 0]) for gene_pair in pair_avals_sub}

In [None]:
# Predictions over the full landscapes using the full, subsampled, and Null (zero-coupling) constants, in that order
pred_dict, pred_dict_sub, pred_dict_null = (
    calc_pair_gr(growth_df_rescale, qPCR_vals, couplings, fit_gene_all, hill_elements)
    for couplings in (pair_avals, pair_avals_sub, pair_avals_null)
)

### Calculate model performance

For all predictions, we calculate two error metrics: AIC and RMSD. AIC accounts for the difference in the number of parameters between models, penalizing those with additional parameters, while RMSD is blind to these differences. We see that the full model outperforms all others, even when accounting for its additional parameters. While the subsampled model's exact error metrics will change between subsampling iterations, we see that it usually outperforms the other models as well, with only slightly diminished accuracy.

In [None]:
def _aic(n_params, errors):
    """Least-squares AIC: 2k + N*log(mean squared error), with N the number of non-NaN errors."""
    n_obs = sum(~np.isnan(errors))
    return 2*n_params + n_obs*np.log(np.nanmean([e**2 for e in errors]))


rms_dict, rms_dict_sub, rms_dict_null = {}, {}, {}
full_model_err, full_sub_err, full_mult_err, full_null_err = [], [], [], []
for gene_pair in pair_avals:
    expt_genepair = fit_gene_all[gene_pair].values
    pred_mult_genepair = mult_pred[gene_pair].values
    # Score double knockdowns only: skip the nontargeting row and column
    pair_cells = [(j, k)
                  for j, guide2 in enumerate(fit_gene_all[gene_pair].index) if guide2 != 'negC_rand_42'
                  for k, guide1 in enumerate(fit_gene_all[gene_pair].columns) if guide1 != 'negC_rand_42']
    err_mult = [expt_genepair[j, k] - pred_mult_genepair[j, k] for j, k in pair_cells]
    err_model = [expt_genepair[j, k] - pred_dict[gene_pair][j, k] for j, k in pair_cells]
    err_sub = [expt_genepair[j, k] - pred_dict_sub[gene_pair][j, k] for j, k in pair_cells]
    err_null = [expt_genepair[j, k] - pred_dict_null[gene_pair][j, k] for j, k in pair_cells]
    full_model_err.extend(err_model)
    full_sub_err.extend(err_sub)
    full_mult_err.extend(err_mult)
    full_null_err.extend(err_null)
    rms_dict[gene_pair] = nanrms(np.array(err_model))
    rms_dict_sub[gene_pair] = nanrms(np.array(err_sub))
    rms_dict_null[gene_pair] = nanrms(np.array(err_null))
avg_rms_model = nanrms(np.array(full_model_err))
avg_rms_sub = nanrms(np.array(full_sub_err))
avg_rms_mult = nanrms(np.array(full_mult_err))
avg_rms_null = nanrms(np.array(full_null_err))
# Parameter counts: two sigmoid parameters per gene, plus two coupling constants per pair for the coupled models
model_k = 2*(len(hill_elements)) + 2*len(pair_avals)
null_k = 2*(len(hill_elements))
mult_k = 0
mult_aic = _aic(mult_k, full_mult_err)
print(f'Multiplicative AIC: {mult_aic}')
null_aic = _aic(null_k, full_null_err)
print(f'Null AIC:           {null_aic}')
model_aic = _aic(model_k, full_model_err)
print(f'Model AIC:          {model_aic}')
sub_aic = _aic(model_k, full_sub_err)
print(f'{np.round((1-sub)*100)}% AIC:          {sub_aic}', end='\n\n')
print(f'Multiplicative RMSD: {np.round(avg_rms_mult, 4)}')
print(f'Null RMSD:           {np.round(avg_rms_null, 4)}')
print(f'Model RMSD:          {np.round(avg_rms_model, 4)}')
print(f'{np.round((1-sub)*100)}% RMSD:          {np.round(avg_rms_sub, 4)}')

### Plot continuous epistasis model performance against Null model performance for all gene pairs

We plot the prediction RMSD across every gene pair for both the complete continuous epistasis model and the Null model. The continuous epistasis model is equivalent or superior for all gene pairs (as expected, since the Null model is the special case with both coupling constants set to zero), and approaches the limit of experimental reproducibility.

In [None]:
# Fig 4A: per-gene-pair prediction RMSD of the Null model (x) vs the continuous epistasis model (y)
fig, ax = plt.subplots(figsize=(6, 6))
# NOTE(review): rms_null / rms_model are not used in this cell; presumably consumed later — confirm
rms_null, rms_model = [], []
for gene_pair in rms_dict:
    rms_null.append(rms_dict_null[gene_pair])
    rms_model.append(rms_dict[gene_pair])
    # The four pairs shown in detail in Figs 4B-4E are drawn as open (white) markers
    if gene_pair in [('dapA', 'purN'), ('gdhA', 'gltB'), ('dapA', 'dapB'), ('purN', 'purL')]:
        ax.scatter(rms_dict_null[gene_pair], rms_dict[gene_pair], s=150, color='w', edgecolors='k', zorder=1)
    else:
        ax.scatter(rms_dict_null[gene_pair], rms_dict[gene_pair], color='xkcd:dark gray', s=150, edgecolors='k',
                   zorder=1)
ax.set_xlim([0, 0.34])
ax.set_ylim([0, 0.34])
# y = x reference: points below the line are pairs where the coupled model has lower RMSD than the Null model
ax.plot([0, 0.34], [0, 0.34], lw=1, ls='--', color='xkcd:gray', zorder=0)
ax.set_xlabel('Null Model RMSD', fontsize=20)
ax.set_ylabel('Epistatic Model RMSD', fontsize=20)
ax.tick_params(axis='both', labelsize=18)
ax.set_xticks([0.1, 0.2, 0.3])
ax.set_yticks([0.1, 0.2, 0.3])
plt.tight_layout()
plt.savefig(f'{figure_path}/Fig4A.pdf')
plt.show()

### Visualize epistasis from the continuous epistasis model, a subsampled model, and simple Bliss epistasis

Coupling values are shown from both the complete and subsampled continuous epistasis models. In addition, Bliss epistasis is calculated for the most severe knockdowns (this approximates a classical knockout study, where epistasis is calculated only at this severe perturbation level). We see that gene-gene epistasis may be most prevalent at these extreme levels, potentially due to a decrease in the signal-to-noise ratio we observe at these low growth rates. When considering the entire expression-growth rate landscape, which we approximate with numerous titrated CRISPRi measurements, gene-gene epistasis is fairly rare.

In [None]:
# Symmetric color limit (+/- edge) for the coupling and epistasis heat maps
edge = 0.6
# Coupling matrix: cell (gene_pair[0], gene_pair[1]) holds avals[0], the transposed cell holds avals[1]
aval_matrix = pd.DataFrame(np.zeros((len(gene_names), len(gene_names))), gene_names, gene_names)
for gene_pair in pair_avals:
    aval_matrix.loc[gene_pair[0], gene_pair[1]] = pair_avals[gene_pair][0]
    aval_matrix.loc[gene_pair[1], gene_pair[0]] = pair_avals[gene_pair][1]
# No self-coupling is fit, so the diagonal is set to NaN
for gene in gene_names:
    aval_matrix.loc[gene, gene] = np.nan
fig, ax = plt.subplots(figsize=(10, 10))
# NOTE(review): '\i' inside these f-strings is an invalid escape (SyntaxWarning on Python 3.12+); an rf-string
# would produce the same italic mathtext label without the warning
ax, cb = heat_map(aval_matrix, ax=ax, xticks=[f'$\it{x}$' for x in gene_names],
                  yticks=[f'$\it{x}$' for x in gene_names], vmin=-edge, vmax=edge, cmap='PuOr_r')
plt.title('Pairwise Couplings ($a_{ij}$/$a_{ji}$)', fontsize=24)
plt.colorbar(cb)
plt.savefig(f'{figure_path}/Fig4F.pdf')
plt.show()

In [None]:
# Classical Bliss epistasis at the most severe knockdowns. exp_extract sorts rows and columns by ascending
# repression, so the last row/column of each landscape are the strongest knockdowns.
aval_matrix = pd.DataFrame(np.zeros((len(gene_names), len(gene_names))), gene_names, gene_names)


def _bliss_epistasis(data, row, col):
    """Measured pairwise growth rate at (row, col) minus the product of the two single-knockdown growth rates."""
    return data.loc[row, col] - data.loc['negC_rand_42', col]*data.loc[row, 'negC_rand_42']


for gene_pair, data in fit_gene_all.items():
    last_row, last_col = data.index[-1], data.columns[-1]
    if not np.isnan(data.loc[last_row, last_col]):
        epistasis = _bliss_epistasis(data, last_row, last_col)
    else:  # Take average of epistasis from adjacent points
        epistasis = np.mean([_bliss_epistasis(data, data.index[-2], last_col),
                             _bliss_epistasis(data, last_row, data.columns[-2])])
    # Bliss epistasis is symmetric, so both triangles get the same value
    aval_matrix.loc[gene_pair[0], gene_pair[1]] = epistasis
    aval_matrix.loc[gene_pair[1], gene_pair[0]] = epistasis

for gene in gene_names:
    aval_matrix.loc[gene, gene] = np.nan
fig, ax = plt.subplots(figsize=(10, 10))
# Raw f-strings keep the mathtext '\it' without relying on an invalid escape sequence
ax, cb = heat_map(aval_matrix, ax=ax, xticks=[rf'$\it{x}$' for x in gene_names],
                  yticks=[rf'$\it{x}$' for x in gene_names], vmin=-edge, vmax=edge)
plt.title('Double Knockdown Epistasis', fontsize=24)
plt.colorbar(cb)
plt.savefig(f'{figure_path}/Fig4G.pdf')
plt.show()

### Visualize pairwise expression-growth rate predictions for certain gene pairs

In [None]:
# Figs 4B-4E: measured vs predicted pairwise growth rates for four example pairs (Null in blue, coupled model in orange)
fig_numbers = ['4B', '4C', '4D', '4E']
for i, gene_pair in enumerate([('dapA', 'purN'), ('gdhA', 'gltB'), ('dapA', 'dapB'), ('purN', 'purL')]):
    fig, ax = plt.subplots(figsize=(6, 6))
    # fit_gene_pairs has NaN in its nontargeting row/column, so those single-knockdown points are not plotted
    ax.scatter(fit_gene_pairs[gene_pair].values.flatten(), pred_dict_null[gene_pair].flatten(), color='b',
               label='Null', edgecolors='k', linewidths=0.5, s=100, alpha=1)
    ax.scatter(fit_gene_pairs[gene_pair].values.flatten(), pred_dict[gene_pair].flatten(), color='xkcd:orange',
               label='Model', edgecolors='k', linewidths=0.5, s=100, alpha=1)
    ax.set_xlim([0, 1.2])
    ax.set_ylim([0, 1.2])
    # y = x line marks perfect prediction
    ax.plot([0, 1.2], [0, 1.2], lw=1, color='xkcd:gray', ls='--', zorder=0)
    ax.legend(fontsize=18)
    ax.set_xlabel('Exp. Growth Rate', fontsize=20)
    ax.set_ylabel('Predicted Growth Rate', fontsize=20)
    ax.set_xticks([0.5, 1])
    ax.set_yticks([0.5, 1])
    ax.tick_params(axis='both', labelsize=18)
    plt.tight_layout()
    plt.savefig(f'{figure_path}/Fig{fig_numbers[i]}.pdf')
    plt.show()

### Import subsampling and regularization optimization data and visualize results

We optimized the subsampling and regularization parameters by repeating the random subsampling procedure 100 times at each of a range of subsampling levels and regularization weights. This computation should be run on a computing cluster if it is to be repeated. We saved the results of these iterations and summarize them below. The code used to generate these files appears further down in two commented-out cells (labeled Subsampling optimization and Regularization optimization). Because these code blocks subsample the data randomly, rerunning them will not reproduce exactly the same output files, although the error metrics, averaged over 100 iterations, will stay within a narrow range.

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    Only use this on trusted, locally generated files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


# Pre-computed cluster results (see the commented-out optimization cells below).
cross_valid_rmsd = _load_pickle(f'{optimization_path}/220815_reg_rms_dict_sub.pickle')
subsampling_avals = _load_pickle(f'{optimization_path}/220815_sub_pair_avals_sub.pickle')
subsampling_grs = _load_pickle(f'{optimization_path}/220815_subsampling_example.pickle')
sub_rmsd = _load_pickle(f'{optimization_path}/220815_sub_rmsd.pickle')

In [None]:
# Figures S6A1/S6A2: two example subsampled dapA-purN growth-rate matrices, each followed by
# the coupling constants fitted at drop fraction 0.8 with regularization weight 10**-1.25.
for example_idx in (0, 1):
    ax, cb = heat_map(subsampling_grs[example_idx], xlabel='dapA', ylabel='purN', cmap='RdBu_r')
    plt.savefig(f'{figure_path}/FigS6A{example_idx + 1}.pdf')
    plt.show()
    example_avals = subsampling_avals[0.8][example_idx][10**-1.25][('dapA', 'purN')]
    print(f"Coupling constants: {example_avals}")

In [None]:
# Score the coupling constants fitted to subsampled data (drop fraction 0.8, i.e. 20 % of the
# pairwise data retained) against the FULL pairwise growth-rate data. For every subsampling
# iteration this records one RMSD per gene pair (rms_dict_sub_range[gene_pair]) and one RMSD
# pooled over all gene pairs (full_sub_rms).
subsampling_level = 0.8
subsampling_reg = 10**-1.25  # regularization weight whose fitted constants are evaluated
control_guide = 'negC_rand_42'
sub_avals_by_iter = subsampling_avals[subsampling_level]
full_sub_rms = []
rms_dict_sub_range = {gene_pair: [] for gene_pair in sub_avals_by_iter[0][subsampling_reg]}
for i in sub_avals_by_iter:
    full_err_sub = []
    for gene_pair, sub_avals in sub_avals_by_iter[i][subsampling_reg].items():
        expt_df = fit_gene_all[gene_pair]
        temp_pred = calc_pair_gr(growth_df_rescale, qPCR_vals, {gene_pair: sub_avals}, fit_gene_all,
                                 hill_elements)
        # Only score double-knockdown cells: exclude every row/column whose guide is the control.
        # np.ix_ selects the same cells, in the same row-major order, as a nested j/k loop would.
        keep = np.ix_(expt_df.index != control_guide, expt_df.columns != control_guide)
        err_sub = (expt_df.values[keep] - temp_pred[gene_pair][keep]).ravel()
        rms_dict_sub_range[gene_pair].append(nanrms(err_sub))
        full_err_sub.extend(err_sub)
    full_sub_rms.append(nanrms(np.array(full_err_sub)))

In [None]:
# Figure S6B: per-gene-pair RMSD of the subsampled model (mean +/- SD over subsampling
# iterations) against the Null-model RMSD; points below the diagonal beat the Null model.
fig, ax = plt.subplots(figsize=(6, 6))
rms_null, rms_sub = [], []
for gene_pair, sub_rmsds in rms_dict_sub_range.items():
    null_rmsd = rms_dict_null[gene_pair]
    mean_sub_rmsd = np.mean(sub_rmsds)
    rms_null.append(null_rmsd)
    rms_sub.append(mean_sub_rmsd)
    ax.errorbar(null_rmsd, mean_sub_rmsd, yerr=np.std(sub_rmsds), fmt='o', ms=10, elinewidth=3,
                color='xkcd:dark gray')
rmsd_limits = [0, 0.34]
rmsd_ticks = [0.1, 0.2, 0.3]
ax.set_xlim(rmsd_limits)
ax.set_ylim(rmsd_limits)
ax.plot(rmsd_limits, rmsd_limits, color='xkcd:gray', lw=1)
ax.set_xticks(rmsd_ticks)
ax.set_yticks(rmsd_ticks)
ax.set_ylabel('Subsampled Model RMSD', fontsize=20)
ax.set_xlabel('Null Model RMSD', fontsize=20)
plt.tight_layout()
plt.savefig(f'{figure_path}/FigS6B.pdf')
plt.show()

In [None]:
# Figure 4H: per-gene-pair RMSD of the subsampled model (mean +/- SD over subsampling
# iterations) against the RMSD of the model fitted to all pairwise data; dashed line is y = x.
fig, ax = plt.subplots(figsize=(6, 6))
rms_full, rms_sub = [], []
for gene_pair, sub_rmsds in rms_dict_sub_range.items():
    full_rmsd = rms_dict[gene_pair]
    mean_sub_rmsd = np.mean(sub_rmsds)
    rms_full.append(full_rmsd)
    rms_sub.append(mean_sub_rmsd)
    ax.errorbar(full_rmsd, mean_sub_rmsd, yerr=np.std(sub_rmsds), fmt='o', ms=10, elinewidth=3,
                color='xkcd:dark gray')
rmsd_limits = [0, 0.34]
rmsd_ticks = [0.1, 0.2, 0.3]
ax.set_xlim(rmsd_limits)
ax.set_ylim(rmsd_limits)
ax.plot(rmsd_limits, rmsd_limits, color='xkcd:gray', lw=1, ls='--')
ax.set_xticks(rmsd_ticks)
ax.set_yticks(rmsd_ticks)
ax.set_ylabel('Subsampled Model RMSD', fontsize=20)
ax.set_xlabel('Full Model RMSD', fontsize=20)
plt.tight_layout()
plt.savefig(f'{figure_path}/Fig4H.pdf')
plt.show()

In [None]:
# Figure S7: distribution of pooled subsampled-model RMSD at each drop fraction, with reference
# lines for the Null model (red) and the drop-fraction-0 (full) model (blue).
fig, ax = plt.subplots(figsize=(12, 6))
violin_data = [np.array(rmsds) for rmsds in sub_rmsd.values()]
sns.violinplot(data=violin_data, scale='width')
ax.set_xticklabels(list(sub_rmsd))
full_model_rmsd = np.mean(sub_rmsd[0])
ax.axhline(avg_rms_null, ls='--', color='r')
ax.axhline(full_model_rmsd, ls='--', color='b')
ax.set_ylabel('Subsampled RMSD', fontsize=18)
ax.set_xlabel('Fraction Dropped', fontsize=18)
plt.text(0.65, 0.135, 'Full Model RMSD', horizontalalignment='center', fontsize=14, color='b')
plt.text(0.65, 0.168, 'Null Model RMSD', horizontalalignment='center', fontsize=14, color='r')
plt.savefig(f'{figure_path}/FigS7.pdf')
plt.show()

In [None]:
# Figure S5: cross-validation RMSD versus regularization weight for each subsampling level.
# Each point is the mean of the cross-validation RMSDs pooled over every gene pair and every
# subsampling iteration at that weight.
fig, ax = plt.subplots(figsize=(6, 6))
scatter_colors = ['b', 'xkcd:orange']
for i, sub in enumerate(cross_valid_rmsd):
    reg_weights = list(cross_valid_rmsd[sub][0])
    temp_rmsd_list = []
    for reg in reg_weights:
        pooled_rmsds = [rmsd for iteration in cross_valid_rmsd[sub].values()
                        for rmsd in iteration[reg].values()]
        temp_rmsd_list.append(np.mean(pooled_rmsds))
    # sub is the fraction dropped, so 0.8 / 0.9 correspond to keeping 20 % / 10 % of the data,
    # matching the manual ax.text annotations below. Other levels get a computed label
    # (this must be an f-string, otherwise the placeholder text is used verbatim).
    if sub == 0.8:
        label = '20%'
    elif sub == 0.9:
        label = '10%'
    else:
        label = f'{np.round((1-sub)*100, 2)}% Subsampling'
    # Cycle through the colours so a third subsampling level does not raise an IndexError.
    ax.scatter(reg_weights, temp_rmsd_list, s=90, color=scatter_colors[i % len(scatter_colors)],
               edgecolor='xkcd:dark gray', label=label)
ax.set_xscale('log')
ax.set_xlabel('Regularization Weight', fontsize=20)
ax.set_ylabel('Model RMSD', fontsize=20)
ax.tick_params(axis='both', labelsize=18)
ax.set_title('Regularized Subsampling', fontsize=24)
ax.text(10**-4*0.7, 0.1585, '10%', color='xkcd:orange', fontsize=20)
ax.text(10**-4*0.7, 0.1435, '20%', color='b', fontsize=20)
ax.set_yticks([0.14, 0.15, 0.16])
ax.set_xticks([10**-4, 10**-2, 10**0, 10**2])
plt.tight_layout()
plt.savefig(f'{figure_path}/FigS5.pdf')
plt.show()

In [None]:
# For each gene pair, show the distributions (over subsampling iterations) of its two fitted
# coupling constants against the fraction of pairwise data dropped. Blue = first constant,
# orange = second; thin horizontal lines mark the constants fitted at drop fraction 0.
reg = 10**-1.25
drop_fractions = list(subsampling_avals)
# Violins sit at x = 25 * drop fraction, so ticks 0..25 are labelled 0..100 % dropped; the second
# constant's violins are shifted by 0.5 so the two sit side by side.
aval1_positions = [x*25 for x in drop_fractions]
aval2_positions = [x*25 + 0.5 for x in drop_fractions]
violin_colors = ['b', 'xkcd:orange']
for gene_pair in subsampling_avals[0][0][reg]:
    fig, ax = plt.subplots(figsize=(6, 6))
    full_aval1, full_aval2 = [], []
    for sub in subsampling_avals:
        full_aval1.append([subsampling_avals[sub][it][reg][gene_pair][0] for it in subsampling_avals[sub]])
        full_aval2.append([subsampling_avals[sub][it][reg][gene_pair][1] for it in subsampling_avals[sub]])
    a1 = ax.violinplot(full_aval1, aval1_positions, showmedians=True, showextrema=False)
    a2 = ax.violinplot(full_aval2, aval2_positions, showmedians=True, showextrema=False)
    for violin, color in zip([a1, a2], violin_colors):
        for pc in violin['bodies']:
            pc.set_color(color)
    ax.hlines(subsampling_avals[0][0][reg][gene_pair][0], 0, 25, color='b', lw=0.5)
    ax.hlines(subsampling_avals[0][0][reg][gene_pair][1], 0, 25, color='xkcd:orange', lw=0.5)
    # Raw f-string: same '$\itNAME$' mathtext title without the invalid Python escape "\i".
    ax.set_title(rf'$\it{gene_pair[0]}$ - $\it{gene_pair[1]}$', fontsize=20)
    ax.set_xlabel('Percent of Data Dropped', fontsize=20)
    ax.set_xticks([0, 5, 10, 15, 20, 25])
    ax.set_xticklabels([0, 20, 40, 60, 80, 100])
    ax.tick_params(axis='both', labelsize=18)
    ax.set_ylabel('Coupling Constant', fontsize=20)
    plt.show()

### Subsampling optimization

This code should be run on a computing cluster rather than locally. The commented-out code below was used to generate the subsampling optimization files that are imported and plotted above.

In [None]:
"""
num_its, reg, sub_list = 100, 10**-1.25, [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
pair_avals_sub, sub_rmsd = {}, {}
for sub in sub_list:
    print(f'Subsampling: {np.round((1-sub)*100, 2)}')
    pair_avals_sub[sub], sub_rmsd[sub] = {}, []
    for i in range(num_its):
        pair_avals_sub[sub][i] = {}
        _, _, fit_gene_pairs_sub = subsample_gr(growth_df_rescale, qPCR_vals, hill_elements, sub)
        pair_avals_sub[sub][i][reg]  = calc_avals(fit_gene_pairs_sub, qPCR_vals, hill_elements, reg)
        full_err_sub = []
        for gene_pair, avals in pair_avals_sub[sub][i][reg].items():
            expt_genepair = fit_gene_all[gene_pair].values
            temp_dict = {}
            temp_dict[gene_pair] = avals
            temp_pred = calc_pair_gr(growth_df_rescale, qPCR_vals, temp_dict, fit_gene_all, hill_elements)
            for j, guide2 in enumerate(list(fit_gene_all[gene_pair].index)):
                for k, guide1 in enumerate(list(fit_gene_all[gene_pair].columns)):
                    if guide1 != 'negC_rand_42' and guide2 != 'negC_rand_42':
                        full_err_sub.append(expt_genepair[j, k] - temp_pred[gene_pair][j, k])
        sub_rmsd[sub].append(nanrms(np.array(full_err_sub)))
        if sub == 0:
            break
with open(f"{optimization_path}/{date}_sub_pair_avals_sub.pickle", 'wb') as handle:
    pickle.dump(pair_avals_sub, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(f"{optimization_path}/{date}_sub_rmsd.pickle", 'wb') as handle:
    pickle.dump(sub_rmsd, handle, protocol=pickle.HIGHEST_PROTOCOL)
"""
pass

### Regularization optimization

This code should be run on a computing cluster rather than locally. The commented-out code below was used to generate the regularization optimization files that are imported and plotted above.

In [None]:
"""
num_its, sub_list = 100, [0.8, 0.9]
reg_list = [10**-4, 10**-3.5, 10**-3, 10**-2.5, 10**-2, 10**-1.75, 10**-1.5, 10**-1.25, 10**-1, 10**-0.5, 10**0,
            10**0.5, 10**1, 10**1.5, 10**2]
rms_dict_sub = {}
for sub in sub_list:
    print(f'Subsampling: {np.round((1-sub)*100, 2)}')
    rms_dict_sub[sub] = {}
    for i in range(num_its):
        rms_dict_sub[sub][i] = {}
        fit_gene_all, fit_gene_pairs, fit_gene_pairs_sub = subsample_gr(growth_df_rescale, qPCR_vals, hill_elements,
                                                                        sub)
        for reg in reg_list:
            pair_avals_sub = calc_avals(fit_gene_pairs_sub, qPCR_vals, hill_elements, reg)
            pred_dict_sub = calc_pair_gr(growth_df_rescale, qPCR_vals, pair_avals_sub, fit_gene_all, hill_elements)
            rms_dict_sub[sub][i][reg] = cross_validate(fit_gene_pairs, fit_gene_pairs_sub, pred_dict_sub)
with open(f"{optimization_path}/{date}_reg_rms_dict_sub.pickle", 'wb') as handle:
    pickle.dump(rms_dict_sub, handle, protocol=pickle.HIGHEST_PROTOCOL)
"""
pass

### Generate tables and export data

In [None]:
# Table S4: for every fitted gene pair, its two coupling constants (aij, aji) and the RMSD of
# the coupled model and of the Null model.
# The rows are collected first and the frame is built in one step. Writing gene-name strings row
# by row into a NaN-initialised (float64) frame forces an incompatible-dtype upcast, which
# pandas >= 2.1 deprecates. It would also leave all-NaN placeholder rows whenever pair_avals holds
# fewer than n*(n-1)/2 pairs.
table_s4 = pd.DataFrame(
    [[gene_pair[0], gene_pair[1], avals[0], avals[1], rms_dict[gene_pair], rms_dict_null[gene_pair]]
     for gene_pair, avals in pair_avals.items()],
    columns=['Gene 1', 'Gene 2', 'aij', 'aji', 'Model RMSD', 'Null RMSD'],
)

In [None]:
import os

# Save the fitted single-gene parameters (hill_elements) and pairwise coupling constants
# (pair_avals) to output_path, then write Table S4 into the shared supplementary workbook.
with open(f'{output_path}/{date}_hill_elements.pickle', 'wb') as handle:
    pickle.dump(hill_elements, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(f'{output_path}/{date}_pair_avals.pickle', 'wb') as handle:
    pickle.dump(pair_avals, handle, protocol=pickle.HIGHEST_PROTOCOL)

supplementary_tables_path = 'Supplementary_Tables.xlsx'
# Append mode only works on an existing workbook (pandas raises FileNotFoundError otherwise), and
# if_sheet_exists is accepted only in append mode, so create the workbook on a fresh run.
if os.path.exists(supplementary_tables_path):
    excel_writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
else:
    excel_writer_kwargs = {'mode': 'w'}
with pd.ExcelWriter(supplementary_tables_path, **excel_writer_kwargs) as writer:
    table_s4.to_excel(writer, sheet_name='Table S4')