# The Continuous Epistasis Model for Gene-by-Environment Interactions

The following code trains and evaluates the continuous epistasis model on up to fourth-order gene-gene-environment-environment data. It uses single- and pairwise CRISPRi perturbations to relate gene expression (relative repression) to growth rate in a variety of media conditions. Single-perturbation growth rate relationships are modeled with two- or four-parameter sigmoidal functions, and perturbation coupling is accounted for using two coupling constants for every pair. The resulting model is assessed by its RMSD relative to a coupling-insensitive Null model. A subsampled model trained on only 20% of the available pairwise perturbation data performs similarly to the full model, supporting a sparse sampling strategy as a way to improve throughput in future experiments. Finally, predictions made using pairwise couplings alone are robust to third- and fourth-order combinations of perturbations.

3/7/22 - Phil Brown
8/7/22 - Ryan Otto

### Import packages and libraries

In [None]:
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.optimize import least_squares
import itertools
import warnings
import plot_defaults
plot_defaults.change_defaults()

### Define relevant variables

In [None]:
# Date tag of this analysis run
date = '220815'

# Directory layout: intermediate pickles, saved optimization sweeps, and figure output
output_path = 'intermediate_files'
optimization_path = 'input_files/optimization_files'
figure_path = 'Figures'

# Perturbation names, in the order used for the four axes of the growth rate grid
perturb_list = ['thymidine', 'methionine', 'folA_KD', 'thyA_KD']

# Titrated intensities of the two media additives. The knockdown intensities
# are appended once the repression data have been loaded.
perturb_dict = {
    'thymidine': [0, 0.05, 1, 2, 5, 10, 50],
    'methionine': [0, 0.01, 0.02, 0.05, 0.1, 0.3, 1.0],
}

### Import data

In [None]:
# Mean growth rates; indexed below with four axes, one per entry of perturb_list
growth_rate_file = f'{output_path}/220815_growth_rates_mean.pickle'
with open(growth_rate_file, 'rb') as growth_rate_handle:
    growth_rates_mean = pickle.load(growth_rate_handle)

# Mean repression per gene, stored as one dictionary of values per gene
repression_file = f'{output_path}/220815_repression_mean_subset.pickle'
with open(repression_file, 'rb') as repression_handle:
    repression_mean_subset = pickle.load(repression_handle)

# Register the measured repression values as the intensities of the two knockdowns
for gene in ('folA', 'thyA'):
    perturb_dict[f'{gene}_KD'] = list(repression_mean_subset[gene].values())

### Define analysis functions

In [None]:
def residuals_growth_rate(params, true_gr, perturbations):
    """Residuals between measured growth rates and the sigmoidal prediction.
    Written to be used as the objective function of least_squares.
    Arguments:
    params: Sigmoid parameters (two or four values) passed on to growth_rate
    true_gr: Measured growth rates
    perturbations: Perturbation intensities at which true_gr was measured
    Returns:
    residuals: Measured minus predicted growth rates, with NaN entries removed
    """
    predicted = growth_rate(perturbations, params)
    residuals = np.array(true_gr - predicted)
    # Missing measurements propagate as NaN; they must not reach the optimizer
    return residuals[~np.isnan(residuals)]


def growth_rate(perturbations, params):
    """Calculates an expected growth rate using a sigmoidal formula.
    The prediction is gmin + (gmax - gmin) / (1 + exp(ni * (p - Do))). With two
    parameters the curve runs from 1 down to 0 (gmax = 1, gmin = 0).
    Arguments:
    perturbations: Perturbation intensity to use when predicting growth rates. A list
                   returns one prediction per element; any other input (scalar or array)
                   is evaluated as a whole and returned with one extra leading axis of
                   length 1, e.g. shape (1, n) for an array of n intensities
    params: Sigmoidal parameters to use in predictions, either [Do, ni] or
            [Do, ni, gmax, gmin]
    Returns:
    grates: Predicted growth rates
    Raises:
    ValueError: If params does not hold exactly two or four values
    """
    if len(params) == 2:
        Do, ni = params
        gmax, gmin = 1, 0
    elif len(params) == 4:
        Do, ni, gmax, gmin = params
    else:
        raise ValueError(f'growth_rate expects 2 or 4 sigmoid parameters, got {len(params)}')
    intensities = np.asarray(perturbations)
    if not isinstance(perturbations, list):
        # Callers rely on the extra leading axis returned for non-list input
        intensities = intensities[np.newaxis]
    with warnings.catch_warnings():
        # During fitting, large exponentials can arise. They overflow to inf, so the
        # sigmoid term evaluates to 0, which is desired; only the warning is silenced.
        warnings.simplefilter('ignore', RuntimeWarning)
        grates = gmin + ((gmax-gmin)/(1+np.exp(ni*(intensities-Do))))
    return grates


def Hill_growth_curve_fxn(perturbations, params, ax):
    """Draw the fitted sigmoid as a smooth blue line.
    Arguments:
    perturbations: Perturbation intensities; only the last entry is used, as the
                   upper end of the plotted range
    params: Sigmoidal parameters to use in predictions
    ax: Axes object to use when plotting
    """
    dense_x = np.arange(0, perturbations[-1], 0.01)
    fitted = growth_rate(dense_x, params)
    # growth_rate returns array input with an extra leading axis; drop it for plotting
    fitted = fitted.reshape(fitted.shape[1], )
    ax.plot(dense_x, fitted, '-b', linewidth=1.5)

In [None]:
def residuals_growthPairs(avals, true_gr, params1, params2, perturb1, perturb2, penalty):
    """Regularized objective used to fit the two coupling constants of one perturbation pair.
    The RMSD between measured and predicted pairwise growth rates is combined with a
    penalty proportional to the absolute value of each coupling constant. Because
    least_squares minimizes the square of what this function returns, the square root
    of that sum is returned so that the sum itself is what gets minimized.
    The result of optimization is only numerically (<10**-8) different without the square root.
    Arguments:
    avals: Coupling constants between the perturbations of interest
    true_gr: True pairwise growth rates
    params1: Sigmoidal growth rate parameters for perturbation 1
    params2: Sigmoidal growth rate parameters for perturbation 2
    perturb1: Perturbation intensities for perturbation 1
    perturb2: Perturbation intensities for perturbation 2
    penalty: Regularization weight applied to the absolute value of each coupling constant
    Returns:
    Square root of (RMSD + penalty term), as a single value
    """
    predicted = pairGrowthRates(avals, perturb1, perturb2, params1, params2)
    residuals = np.array((true_gr - predicted).flatten())
    measured_only = residuals[~np.isnan(residuals)]  # Skip combinations without a measurement
    fit_term = nanrms(measured_only)
    penalty_term = sum(penalty*abs(a) for a in avals)
    return np.sqrt(fit_term + penalty_term)


def pairGrowthRates(avals, perturb1, perturb2, p1, p2):
    """Predict growth rates for every combination of two coupled perturbations.
    Each perturbation's effective intensity is solved for first; the pairwise
    prediction is the product of the two single-perturbation sigmoids evaluated
    at those effective intensities.
    Arguments:
    avals: Coupling constants between the perturbations of interest
    perturb1: First set of perturbation values
    perturb2: Second set of perturbation values
    p1: Sigmoidal growth rate parameters for perturbation 1 (Do is the first entry)
    p2: Sigmoidal growth rate parameters for perturbation 2 (Do is the first entry)
    Returns:
    Predicted pairwise growth rates, shaped (len(perturb1), len(perturb2))
    """
    eff1, eff2, _, _ = solvePeff(avals, perturb1, perturb2, p1[0], p2[0])
    rates1 = np.zeros(np.shape(eff1))
    rates2 = np.zeros(np.shape(eff2))
    # Evaluate the sigmoids one row of the grid at a time
    for row, intensities in enumerate(eff1):
        rates1[row, :] = growth_rate(intensities, p1)
    for row, intensities in enumerate(eff2):
        rates2[row, :] = growth_rate(intensities, p2)
    return rates1 * rates2


def solvePeff(avals, perturb1, perturb2, Do_1, Do_2):
    """Solve for the effective intensity of two perturbations given two coupling constants.
    The effective intensities are found by fixed-point iteration, starting from the
    applied intensities. Both perturbations are updated simultaneously from the
    previous iterate.
    Arguments:
    avals: Coupling constants between the perturbations of interest; avals[0] scales the
           effect of perturbation 2 on perturbation 1, avals[1] the reverse
    perturb1: First set of perturbation values
    perturb2: Second set of perturbation values
    Do_1: Do parameter for perturbation 1
    Do_2: Do parameter for perturbation 2
    Returns:
    p1eff: 2D array of the first perturbation's relative intensity after accounting for coupling
    p2eff: 2D array of the second perturbation's relative intensity after accounting for coupling
    resids_p1: List of residuals for the first perturbation, returned for troubleshooting and optimization
    resids_p2: List of residuals for the second perturbation, returned for troubleshooting and optimization
    """
    # Applied intensities on the (len(perturb1), len(perturb2)) grid
    applied1 = np.tile(np.array(perturb1), (len(perturb2), 1)).T
    applied2 = np.tile(np.array(perturb2), (len(perturb1), 1))
    p1eff, p2eff = applied1, applied2
    eps = 0.01  # Desired final sum of residuals
    resids_p1, resids_p2 = [], []
    # Iterate until the summed change per step falls to eps, for at most 100 iterations
    for _ in range(100):
        next1 = applied1 / (1 + avals[0]*((p2eff/Do_2)/(1 + (p2eff/Do_2))))
        next2 = applied2 / (1 + avals[1]*((p1eff/Do_1)/(1 + (p1eff/Do_1))))
        resids_p1.append(np.sum(abs(p1eff - next1)))
        resids_p2.append(np.sum(abs(p2eff - next2)))
        p1eff, p2eff = next1, next2
        if not (resids_p1[-1] + resids_p2[-1] > eps):
            break
    return p1eff, p2eff, resids_p1, resids_p2


def nanrms(x, axis=None):
    """Root mean square of an array, ignoring NaN entries.
    Arguments:
    x: Array of residuals
    axis: Axis to average across; by default all entries are pooled
    Directly returns the RMSD of the residuals provided
    """
    squared = x**2
    mean_square = np.nanmean(squared, axis=axis)
    return np.sqrt(mean_square)

In [None]:
def QuadGrowthRates(avals, perturbations, params):
    """Calculate predicted growth rates for four perturbations
    The prediction for each combination is the product of the four single-perturbation
    sigmoids, each evaluated at that perturbation's effective (coupling-adjusted) intensity.
    Arguments:
    avals: Coupling constants between all perturbations of interest
    perturbations: All relevant perturbation intensities, as a dictionary holding exactly
                   four entries
    params: Single perturbation-growth rate sigmoidal function parameters, as a dictionary
            holding exactly four entries in the same order as perturbations; the first
            value of each entry is used as Do
    Returns:
    gr_quad: Predicted fourth-order growth rates, one axis per perturbation
    """
    perturb1, perturb2, perturb3, perturb4 = list(perturbations.values())
    per1, per2, per3, per4 = [params[x] for x in params]
    perturb_eff, _ = solvePeffQuad(avals, perturb1, perturb2, perturb3, perturb4,
                                   per1[0], per2[0], per3[0], per4[0])
    [p1eff, p2eff, p3eff, p4eff] = perturb_eff
    # Evaluate each sigmoid once over its whole grid instead of once per grid cell.
    # Besides being faster, this avoids storing size-1 arrays into scalar array
    # elements, which recent NumPy releases deprecate. growth_rate returns array
    # input with an extra leading axis, so each result is reshaped back onto the grid.
    gr1 = np.reshape(growth_rate(p1eff, per1), np.shape(p1eff))
    gr2 = np.reshape(growth_rate(p2eff, per2), np.shape(p2eff))
    gr3 = np.reshape(growth_rate(p3eff, per3), np.shape(p3eff))
    gr4 = np.reshape(growth_rate(p4eff, per4), np.shape(p4eff))
    gr_quad = gr1 * gr2 * gr3 * gr4
    return gr_quad


def solvePeffQuad(avals, perturb1, perturb2, perturb3, perturb4, Do_1, Do_2, Do_3, Do_4):
    """Solve for the effective perturbation strength of four perturbations, given twelve coupling constants
    describing all pairwise couplings between them.
    The effective intensities are found by fixed-point iteration, starting from the applied
    intensities. All four perturbations are updated simultaneously from the previous iterate.
    Arguments:
    avals: Dictionary of all coupling constants. Its values are concatenated in dictionary
           order and must follow the pair order 1-2, 1-3, 1-4, 2-3, 2-4, 3-4. Within each
           pair, the first constant scales the effect of the second perturbation on the
           first, and the second constant the reverse
    perturbX: Perturbation intensities for perturbation X
    Do_X: Do parameter for perturbation X
    Returns:
    perturb_eff_list: Four arrays (one axis per perturbation) of each perturbation's relative
                      intensity after accounting for coupling
    resids_list: Four lists of residuals, one entry per iteration, returned for
                 troubleshooting and optimization
    """
    avals_list = np.concatenate(list(avals.values()))
    applied = [perturb1, perturb2, perturb3, perturb4]
    midpoints = [Do_1, Do_2, Do_3, Do_4]
    grid_shape = tuple(len(values) for values in applied)
    # coupling[target] lists (source perturbation, index into avals_list) for every
    # perturbation that modulates the target
    coupling = [
        [(1, 0), (2, 2), (3, 4)],
        [(0, 1), (2, 6), (3, 8)],
        [(0, 3), (1, 7), (3, 10)],
        [(0, 5), (1, 9), (2, 11)],
    ]
    # Applied intensity of each perturbation, varying along its own axis only. These
    # grids are both the numerators of the update formulas and the starting point of
    # the iteration.
    applied_grid = []
    for axis, values in enumerate(applied):
        axis_shape = [1, 1, 1, 1]
        axis_shape[axis] = len(values)
        column = np.asarray(values, dtype=float).reshape(axis_shape)
        applied_grid.append(np.broadcast_to(column, grid_shape))
    perturb_eff_list = list(applied_grid)
    resids_list = [[], [], [], []]
    resids, eps, count = np.inf, 0.01, 0  # Initialize sum of residuals, desired final residual, and count iterator
    # Iteratively reduce the residuals using the update formulas
    # If the desired residual is not reached, exit after 100 iterations
    while resids > eps and count < 100:
        saturation = [(eff/Do)/(1 + (eff/Do)) for eff, Do in zip(perturb_eff_list, midpoints)]
        updated = []
        for target, sources in enumerate(coupling):
            attenuation = 1.0
            for source, a_index in sources:
                attenuation = attenuation * (1 + avals_list[a_index]*saturation[source])
            updated.append(applied_grid[target] / attenuation)
        for history, previous, current in zip(resids_list, perturb_eff_list, updated):
            history.append(np.sum(abs(previous - current)))
        resids = sum(history[-1] for history in resids_list)
        perturb_eff_list = updated
        count += 1
    return perturb_eff_list, resids_list


def subsample(data, frac_drop):
    """Split a grid of pairwise growth rates into a training set and a held-out test set.
    The first row and first column are blanked in the training set and are never
    moved to the test set. At least one measurement is always left in the training set.
    Arguments:
    data: Full grid of growth rate data to be subsampled
    frac_drop: A fraction (between 0 and 1) of the growth rate data to remove for subsampling
    Returns:
    training: Growth rate data after removing the relevant fraction of data
    test: All removed data, to be used in model evaluation
    """
    training = data.copy()
    training[0, :] = np.nan  # Remove first-order data
    training[:, 0] = np.nan  # Remove first-order data
    test = np.full(np.shape(training), np.nan)
    n_rows = np.shape(training)[0]
    n_available = sum(~np.isnan(training.flatten()))  # Measurements that could be dropped
    shuffled = np.arange(len(training.flatten()))
    np.random.shuffle(shuffled)  # Randomize index order
    # Drop the requested fraction (rounded up) without ever dropping all data points
    n_drop = max(0, min(int(np.ceil(n_available*frac_drop)), n_available - 1))
    # Candidates are visited in the order of the shuffled indices applied to themselves
    visit_order = shuffled[shuffled]
    rows = visit_order % n_rows
    cols = visit_order // n_rows
    # Entries that are already NaN are skipped and do not count toward the total
    droppable = ~np.isnan(training[rows, cols])
    drop_rows = rows[droppable][:n_drop]
    drop_cols = cols[droppable][:n_drop]
    test[drop_rows, drop_cols] = data[drop_rows, drop_cols]  # Move to the test set
    training[drop_rows, drop_cols] = np.nan
    return training, test

### Fit single-perturbation growth rate curves

As CRISPRi perturbations should monotonically decrease growth rate from 1 to 0, we use a two-parameter sigmoid to fit the expression-growth rate function. As media additives could increase growth rate to different maxima, we use a four-parameter sigmoid for those perturbations. We then plot each single-perturbation function individually.

In [None]:
# Single-perturbation growth rates: vary one axis, hold the other three at index 0
single_gr = {
    'thymidine': growth_rates_mean[:, 0, 0, 0],
    'methionine': growth_rates_mean[0, :, 0, 0],
    'folA_KD': growth_rates_mean[0, 0, :, 0],
    'thyA_KD': growth_rates_mean[0, 0, 0, :],
}
hill_params = {}

# Media additives: four-parameter sigmoid [Do, ni, gmax, gmin]. gmax and gmin are
# bounded by the observed extremes, offset by 0.001 so that the two bounds of each
# parameter never coincide.
additive_Do_guess = {'thymidine': 1, 'methionine': 0.5}
for perturbation, initial_Do in additive_Do_guess.items():
    observed = single_gr[perturbation]
    gr_floor, gr_ceiling = min(observed), max(observed)
    additive_fit = least_squares(
        residuals_growth_rate,
        x0=[initial_Do, 0, gr_ceiling, gr_floor],
        bounds=[(0, -np.inf, gr_floor+0.001, gr_floor),
                (np.inf, np.inf, gr_ceiling, gr_ceiling-0.001)],
        args=(observed, perturb_dict[perturbation]))
    hill_params[perturbation] = additive_fit.x

# Knockdowns: two-parameter sigmoid [Do, ni], fit against measured repression
for perturbation in ['folA_KD', 'thyA_KD']:
    gene = perturbation[0:4]  # Gene name without the '_KD' suffix
    repression = list(repression_mean_subset[gene].values())
    knockdown_fit = least_squares(residuals_growth_rate, x0=[0.5, 1],
                                  bounds=[(0, -np.inf), (np.inf, np.inf)],
                                  args=(single_gr[perturbation], repression))
    hill_params[perturbation] = knockdown_fit.x

In [None]:
# Media additives: measured growth rates (points) and fitted sigmoid (line), one panel each.
# This figure is saved as FigS9.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for i, perturbation in enumerate(['thymidine', 'methionine']):
    ax[i].plot(perturb_dict[perturbation], single_gr[perturbation], 'o', color='xkcd:dark gray')
    Hill_growth_curve_fxn(perturb_dict[perturbation], hill_params[perturbation], ax[i])
    ax[i].set_ylim([0, 2])
    # The two additives are labeled with different concentration units
    if perturbation == 'thymidine':
        units = 'ng/μL'
    elif perturbation == 'methionine':
        units = 'mM'
    ax[i].set_title(perturbation, fontsize=20)
    ax[i].set_xlabel(f"[{perturbation}] {units}", fontsize=14)
    ax[i].set_ylabel('Growth Rate', fontsize=14)
plt.tight_layout()
plt.savefig(f'{figure_path}/FigS9.pdf')
plt.show()
# Knockdowns: same layout, with measured repression on the x axis. This figure is only
# displayed, not saved.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for i, perturbation in enumerate(['folA_KD', 'thyA_KD']):
    # perturbation[0:4] is the gene name without the '_KD' suffix
    ax[i].plot(repression_mean_subset[perturbation[0:4]].values(), single_gr[perturbation], "ok")
    Hill_growth_curve_fxn(list(repression_mean_subset[perturbation[0:4]].values()), hill_params[perturbation], ax[i])
    ax[i].set_ylim([0, 2])
    ax[i].set_title(perturbation, fontsize=20)
    ax[i].set_xlabel("Repression", fontsize=14)
    ax[i].set_ylabel("Growth Rate", fontsize=14)
plt.tight_layout()
plt.show()

### Fit coupling constants and calculate predicted growth rates

We fit coupling constants to pairwise perturbation-growth rate data, then use these coupling constants and single-gene sigmoid parameters to generate predicted growth rates for each perturbation combination -- up to fourth order. We also predict growth rates using a Null model without coupling and a multiplicative (Bliss) model that considers only perturbations directly investigated and cannot make continuous predictions.

In [None]:
# Pairwise growth rate grids: vary two axes, hold the other two at index 0
pair_gr = {
    'thymidine-methionine': growth_rates_mean[:, :, 0, 0],
    'thymidine-folA_KD': growth_rates_mean[:, 0, :, 0],
    'thymidine-thyA_KD': growth_rates_mean[:, 0, 0, :],
    'methionine-folA_KD': growth_rates_mean[0, :, :, 0],
    'methionine-thyA_KD': growth_rates_mean[0, :, 0, :],
    'folA_KD-thyA_KD': growth_rates_mean[0, 0, :, :],
}

# Fit two unbounded coupling constants per pair, starting from no coupling
coupling_bounds = [(-np.inf, -np.inf), (np.inf, np.inf)]
coupling_penalty = 10**-1.25  # Regularization weight
pair_avals = {}
for per1, per2 in itertools.combinations(perturb_list, 2):
    perturb_pair = f'{per1}-{per2}'
    pair_fit = least_squares(residuals_growthPairs, x0=[0, 0], bounds=coupling_bounds,
                             args=(pair_gr[perturb_pair], hill_params[per1], hill_params[per2],
                                   perturb_dict[per1], perturb_dict[per2], coupling_penalty))
    pair_avals[perturb_pair] = pair_fit.x

# Arrange the constants in a square matrix: the first constant of each pair goes above
# the diagonal, the second one below it. The diagonal stays NaN.
n_perturb = len(perturb_list)
avals_matrix = np.full([n_perturb, n_perturb], np.nan)
for (i, per1), (j, per2) in itertools.combinations(enumerate(perturb_list), 2):
    fitted = pair_avals[f'{per1}-{per2}']
    avals_matrix[i, j] = fitted[0]
    avals_matrix[j, i] = fitted[1]
# Heat map of the fitted coupling constants on a diverging color scale clipped to +/-0.75
fig, ax = plt.subplots(figsize=(7, 7))
cb = ax.imshow(avals_matrix, cmap='PuOr_r', vmin=-0.75, vmax=0.75)
ax.set_yticks(np.arange(len(perturb_list)), labels=perturb_list)
ax.set_xticks(np.arange(len(perturb_list)), labels=perturb_list, rotation=90)
# NOTE(review): presumably the grey background is what shows through for the NaN
# diagonal cells of avals_matrix -- confirm against the rendered figure
ax.set_facecolor("darkgrey")
# Draw the top and right spines so the matrix is framed on all sides
ax.spines[['top', 'right']].set_visible(True)
plt.colorbar(cb)
plt.tight_layout()
plt.show()
print(f"Thymidine-thyA_KD coupling: {pair_avals['thymidine-thyA_KD']}")

In [None]:
# Coupled (epistatic) model and a Null model with every coupling constant set to zero
gr_epistatic_pred = QuadGrowthRates(pair_avals, perturb_dict, hill_params)
no_coupling = {perturb_pair: [0, 0] for perturb_pair in pair_avals}
gr_null_pred = QuadGrowthRates(no_coupling, perturb_dict, hill_params)

# Bliss model: product of the four measured single-perturbation growth rates, formed
# for every combination by broadcasting each one along its own axis
bliss_quad_gr = np.full_like(growth_rates_mean, np.nan)
bliss_quad_gr[...] = (growth_rates_mean[:, 0, 0, 0][:, None, None, None]
                      * growth_rates_mean[0, :, 0, 0][None, :, None, None]
                      * growth_rates_mean[0, 0, :, 0][None, None, :, None]
                      * growth_rates_mean[0, 0, 0, :][None, None, None, :])

# RMSD of each model over the entire grid
measured_flat = growth_rates_mean.flatten()
bliss_RMSD = nanrms(measured_flat - bliss_quad_gr.flatten())
null_RMSD = nanrms(measured_flat - gr_null_pred.flatten())
epistatic_RMSD = nanrms(measured_flat - gr_epistatic_pred.flatten())
# Measured vs. predicted growth rate for the three models, one panel per model,
# each annotated with that model's RMSD over the whole grid
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].plot(growth_rates_mean.flatten(), bliss_quad_gr.flatten(), 'go', markersize=3, alpha=0.15)
ax[0].text(0, 2, f'RMSD = {np.round(bliss_RMSD, 4)}', fontsize=20)
ax[0].set_title('Bliss Model', fontsize=20)
ax[1].plot(growth_rates_mean.flatten(), gr_null_pred.flatten(), 'ko', markersize=3, alpha=0.15)
ax[1].text(0, 1.5, f'RMSD = {np.round(null_RMSD, 4)}', fontsize=20)
ax[1].set_title('Null Model - No Coupling', fontsize=20)
ax[2].plot(growth_rates_mean.flatten(), gr_epistatic_pred.flatten(), 'bo', markersize=3, alpha=0.15)
ax[2].text(0.75, 0.1, f'RMSD = {np.round(epistatic_RMSD, 4)}', fontsize=20)
ax[2].set_title('Epistatic Model', fontsize=20)
for i in [0, 1, 2]:
    # Red dashed identity line: points on it are predicted exactly
    ax[i].plot([0, 1.2], [0, 1.2], 'r--')
    ax[i].set_xlabel('Measured Growth Rate', fontsize=16)
    ax[i].set_ylabel('Predicted Growth Rate', fontsize=16)
plt.tight_layout()
plt.show()

### Pull out second-order perturbation data and high-order (third and fourth) perturbation data

In [None]:
# Index 0 along an axis is excluded (1:) for every perturbation that takes part in a
# combination and selected (0) for every perturbation that does not.
second_order_regimes = [
    np.s_[1:, 1:, 0, 0],
    np.s_[1:, 0, 1:, 0],
    np.s_[1:, 0, 0, 1:],
    np.s_[0, 1:, 1:, 0],
    np.s_[0, 1:, 0, 1:],
    np.s_[0, 0, 1:, 1:],
]
high_order_regimes = [
    np.s_[1:, 1:, 1:, 0],
    np.s_[1:, 0, 1:, 1:],
    np.s_[1:, 1:, 0, 1:],
    np.s_[0, 1:, 1:, 1:],
    np.s_[1:, 1:, 1:, 1:],
]


def _gather_regimes(grid, regimes):
    """Concatenate the flattened entries of grid for each index expression in regimes."""
    return np.concatenate([grid[regime].flatten() for regime in regimes])


true_2nd_order_gr = _gather_regimes(growth_rates_mean, second_order_regimes)
true_high_order_gr = _gather_regimes(growth_rates_mean, high_order_regimes)
epistatic_2nd_order_pred = _gather_regimes(gr_epistatic_pred, second_order_regimes)
epistatic_high_order_pred = _gather_regimes(gr_epistatic_pred, high_order_regimes)
null_high_order_pred = _gather_regimes(gr_null_pred, high_order_regimes)
null_2nd_order_pred = _gather_regimes(gr_null_pred, second_order_regimes)

# Report the RMSD of both models, separately for pairwise and higher-order combinations
pairwise_epistatic_rmsd = nanrms(epistatic_2nd_order_pred - true_2nd_order_gr)
pairwise_null_rmsd = nanrms(null_2nd_order_pred - true_2nd_order_gr)
high_order_epistatic_rmsd = nanrms(epistatic_high_order_pred - true_high_order_gr)
high_order_null_rmsd = nanrms(null_high_order_pred - true_high_order_gr)
print(f'Epistatic, Pairwise:   {np.round(pairwise_epistatic_rmsd, 5)}')
print(f'Null, Pairwise:        {np.round(pairwise_null_rmsd, 5)}')
print(f'Epistatic, High Order: {np.round(high_order_epistatic_rmsd, 5)}')
print(f'Null, High Order:      {np.round(high_order_null_rmsd, 5)}')

### Calculate and visualize prediction error within different perturbation regimes.

In [None]:
# Region of the growth rate grid belonging to each combination of perturbations:
# 1: for a perturbation that takes part, 0 for one that does not
regime_index = {
    'thymidine-methionine': np.s_[1:, 1:, 0, 0],
    'thymidine-folA_KD': np.s_[1:, 0, 1:, 0],
    'thymidine-thyA_KD': np.s_[1:, 0, 0, 1:],
    'methionine-folA_KD': np.s_[0, 1:, 1:, 0],
    'methionine-thyA_KD': np.s_[0, 1:, 0, 1:],
    'folA_KD-thyA_KD': np.s_[0, 0, 1:, 1:],
    'thymidine-methionine-folA_KD': np.s_[1:, 1:, 1:, 0],
    'thymidine-methionine-thyA_KD': np.s_[1:, 1:, 0, 1:],
    'thymidine-folA_KD-thyA_KD': np.s_[1:, 0, 1:, 1:],
    'methionine-folA_KD-thyA_KD': np.s_[0, 1:, 1:, 1:],
    'thymidine-methionine-folA_KD-thyA_KD': np.s_[1:, 1:, 1:, 1:],
}
model_predictions = {'Epistatic': gr_epistatic_pred, 'Null': gr_null_pred}
# RMSD of each model within each regime
rmsd_dict = {
    model: {
        regime: nanrms(predicted[index].flatten() - growth_rates_mean[index].flatten())
        for regime, index in regime_index.items()
    }
    for model, predicted in model_predictions.items()
}

In [None]:
# One row per perturbation regime, with the Null model RMSD in blue and the coupled
# model RMSD in red. Values and tick labels are both reversed, so the first regime of
# rmsd_dict ends up in the top row.
fig, ax = plt.subplots(figsize=(6, 6))
for i, [epistatic, null] in enumerate(zip(list(rmsd_dict['Epistatic'].values())[::-1],
                                          list(rmsd_dict['Null'].values())[::-1])):
    ax.scatter(null, i, color='b', s=75, ec='xkcd:dark gray')
    ax.scatter(epistatic, i, color='r', s=75, ec='xkcd:dark gray')
ax.tick_params(axis='both', labelsize=14)
ax.set_yticks(range(len(list(rmsd_dict['Epistatic'].values()))))
# Abbreviated labels, written in the same order as the keys of rmsd_dict before reversal
ax.set_yticklabels(['thy-met', 'thy-$folA$', 'thy-$thyA$', 'met-$folA$', 'met-$thyA$', '$folA$-$thyA$',
                    'thy-met-$folA$', 'thy-met-$thyA$', 'thy-$folA$-$thyA$', 'met-$folA$-$thyA$',
                    'thy-met-$folA$-$thyA$'][::-1])
ax.set_xticks([0.25, 0.5, 0.75])
ax.set_xlabel('Model RMSD', fontsize=18)
ax.set_xlim([0, 0.75])
# Hand-placed text serves as the color legend
ax.text(0.45, 9.7, 'Coupling', color='r', fontsize=18)
ax.text(0.45, 9, 'Null', color='b', fontsize=18)
plt.tight_layout()
plt.savefig(f'{figure_path}/Fig6D.pdf', transparent=True)
plt.show()

### Effects of regularization and subsampling on the model.

We investigated the effects of changing subsampling and regularization parameters by repeating the subsampling procedure 100 times with many different internal parameters. This code should be run on a computing cluster if it is to be repeated. We have saved the output of our repeated iterations, and summarize results below. Model sensitivity was comparable to the model iteration trained on pairwise CRISPRi data, so we used the same subsampling and regularization parameters.

In [None]:
# Saved results of the regularization sweep
regularization_file = f'{optimization_path}/220815_reg_opt_rmsd.pickle'
with open(regularization_file, 'rb') as regularization_handle:
    regularization_rmsd = pickle.load(regularization_handle)

# Saved results of the subsampling sweep
subsampling_file = f'{optimization_path}/220815_sub_opt_rmsd.pickle'
with open(subsampling_file, 'rb') as subsampling_handle:
    subsampling_rmsd = pickle.load(subsampling_handle)

In [None]:
# Mean RMSD against regularization weight, one series per subsampling level.
# Only the levels listed here are drawn; each gets the legend entry shown.
legend_entry = {0.8: '20%', 0.9: '10%'}
series_colors = ['b', 'xkcd:orange']
fig, ax = plt.subplots(figsize=(6, 6))
for i, sub in enumerate(regularization_rmsd):
    if sub not in legend_entry:
        continue
    rmsd_by_weight = regularization_rmsd[sub]
    mean_rmsd = [np.mean(x) for x in rmsd_by_weight.values()]
    ax.scatter(list(rmsd_by_weight.keys()), mean_rmsd, s=90, color=series_colors[i],
               edgecolor='xkcd:dark gray', label=legend_entry[sub])
ax.set_xscale('log')
ax.set_xlabel('Regularization Weight', fontsize=20)
ax.set_ylabel('Model RMSD', fontsize=20)
ax.tick_params(axis='both', labelsize=18)
ax.legend()
plt.tight_layout()
plt.savefig(f'{figure_path}/FigS12.pdf')
plt.show()

In [None]:
# Distribution of pairwise RMSD for each key of subsampling_rmsd (one violin per key)
fig, ax = plt.subplots(figsize=(12, 6))
RMSD_labels = list(subsampling_rmsd.keys())
RMSD_toPlot = list(np.array(subsampling_rmsd[sub]['pair']) for sub in RMSD_labels)
sns.violinplot(data=RMSD_toPlot, scale = 'width')
ax.set_xticklabels(RMSD_labels)
# Reference lines computed from the models fit above: Null model in red, full
# (non-subsampled) coupled model in blue
ax.axhline(nanrms(null_2nd_order_pred - true_2nd_order_gr), ls='--', color='r')
ax.axhline(nanrms(epistatic_2nd_order_pred - true_2nd_order_gr), ls='--', color='b')
ax.set_ylabel('Subsampled RMSD', fontsize=18)
ax.set_xlabel('Fraction Dropped', fontsize=18)
ax.set_title('Pairwise Predictions', fontsize=18)
# Labels for the reference lines, placed at hard-coded data coordinates
plt.text(0.8, 0.093, "Full Model RMSD", horizontalalignment='center', fontsize=14, color='b')
plt.text(0.8, 0.3, "Null Model RMSD", horizontalalignment='center', fontsize=14, color='r')
plt.tight_layout()
plt.savefig(f'{figure_path}/FigS11A.pdf')
plt.show()

In [None]:
# Distribution of third- and fourth-order RMSD for each key of subsampling_rmsd
# (one violin per key)
fig, ax = plt.subplots(figsize=(12, 6))
RMSD_labels = list(subsampling_rmsd.keys())
RMSD_toPlot = list(np.array(subsampling_rmsd[sub]['high']) for sub in RMSD_labels)
sns.violinplot(data=RMSD_toPlot, scale='width')
ax.set_xticklabels(RMSD_labels)
# Red reference line: Null model, computed from the predictions made above
ax.axhline(nanrms(null_high_order_pred - true_high_order_gr), ls='--', color='r')
# Blue reference line: mean of the saved results stored under key 0, rather than a
# value recomputed from gr_epistatic_pred
ax.axhline(np.mean(subsampling_rmsd[0]['high']), ls='--', color='b')
ax.set_ylabel('Subsampled RMSD', fontsize=18)
ax.set_xlabel('Fraction Dropped', fontsize=18)
ax.set_title('Third- and Fourth Order Predictions', fontsize=18)
# Labels for the reference lines, placed at hard-coded data coordinates
plt.text(0.8, 0.235, "Full Model RMSD", horizontalalignment='center', fontsize=14, color='b')
plt.text(0.8, 0.53, "Null Model RMSD", horizontalalignment='center', fontsize=14, color='r')
plt.tight_layout()
plt.savefig(f'{figure_path}/FigS11B.pdf')
plt.show()

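### The following cells are disabled (wrapped in string literals) and should (probably) not be run locally. They regenerate the pickled RMSD results saved under `optimization_path`.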

In [None]:
"""
num_its = 100
sub_list = [0.8, 0.9]
reg_list = [10**-4, 10**-3.5, 10**-3, 10**-2.5, 10**-2, 10**-1.75, 10**-1.5, 10**-1.25, 10**-1, 10**-0.5, 10**0,
            10**0.5, 10**1, 10**1.5, 10**2]
reg_rmsd_store = {}
for sub in sub_list:
    reg_rmsd_store[sub] = {reg:[] for reg in reg_list}
    print(f'Subsampling: {sub}')
    for i in range(num_its):
        print(f'Iteration {i+1}, ', end='')
        reg_gr = {}
        reg_gr['thymidine-methionine'] = subsample(growth_rates_mean[:, :, 0, 0], sub)
        reg_gr['thymidine-folA_KD'] = subsample(growth_rates_mean[:, 0, :, 0], sub)
        reg_gr['thymidine-thyA_KD'] = subsample(growth_rates_mean[:, 0, 0, :], sub)
        reg_gr['methionine-folA_KD'] = subsample(growth_rates_mean[0, :, :, 0], sub)
        reg_gr['methionine-thyA_KD'] = subsample(growth_rates_mean[0, :, 0, :], sub)
        reg_gr['folA_KD-thyA_KD'] = subsample(growth_rates_mean[0, 0, :, :], sub)
        for reg in reg_list:
            pair_avals_reg = {}
            for (per1, per2) in itertools.combinations(perturb_list, 2):
                perturb_pair = f'{per1}-{per2}'
                pair_avals_reg[perturb_pair] = least_squares(residuals_growthPairs, x0=[0, 0],
                                                    bounds=[(-np.inf, -np.inf), (np.inf, np.inf)],
                                                    args=(reg_gr[perturb_pair][0], hill_params[per1],
                                                    hill_params[per2], perturb_dict[per1], perturb_dict[per2], reg)).x
            reg_gr_epistatic_pred = QuadGrowthRates(pair_avals_reg, perturb_dict, hill_params)
            reg_epistatic_high_order_pred = np.concatenate((reg_gr_epistatic_pred[1:, 1:, 1:, 0].flatten(),
                                             reg_gr_epistatic_pred[1:, 0, 1:, 1:].flatten(),
                                            reg_gr_epistatic_pred[1:, 1:, 0, 1:].flatten(),
                                            reg_gr_epistatic_pred[0, 1:, 1:, 1:].flatten(),
                                       reg_gr_epistatic_pred[1:, 1:, 1:, 1:].flatten()))
            reg_rmsd_store[sub][reg].append(nanrms(reg_epistatic_high_order_pred - true_high_order_gr))
with open(f"{optimization_path}/{date}_reg_opt_rmsd.pickle", 'wb') as handle:
    pickle.dump(reg_rmsd_store, handle, protocol=pickle.HIGHEST_PROTOCOL)
"""
pass

In [None]:
"""
num_its = 100
sub_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98]
sub_rmsd_store = {}
for sub in sub_list:
    sub_rmsd_store[sub] = {'pair':[], 'high':[]}
    print(f'Subsampling: {sub}')
    for i in range(num_its):
        if (i+1) % 10 == 0:
            print(f'Iteration {i+1}, ', end='')
        sub_gr = {}
        sub_gr['thymidine-methionine'] = subsample(growth_rates_mean[:, :, 0, 0], sub)
        sub_gr['thymidine-folA_KD'] = subsample(growth_rates_mean[:, 0, :, 0], sub)
        sub_gr['thymidine-thyA_KD'] = subsample(growth_rates_mean[:, 0, 0, :], sub)
        sub_gr['methionine-folA_KD'] = subsample(growth_rates_mean[0, :, :, 0], sub)
        sub_gr['methionine-thyA_KD'] = subsample(growth_rates_mean[0, :, 0, :], sub)
        sub_gr['folA_KD-thyA_KD'] = subsample(growth_rates_mean[0, 0, :, :], sub)
        pair_avals_sub = {}
        for (per1, per2) in itertools.combinations(perturb_list, 2):
            perturb_pair = f'{per1}-{per2}'
            pair_avals_sub[perturb_pair] = least_squares(residuals_growthPairs, x0=[0, 0],
                                            bounds=[(-np.inf, -np.inf), (np.inf, np.inf)],
                                            args=(sub_gr[perturb_pair][0], hill_params[per1],
                                            hill_params[per2], perturb_dict[per1], perturb_dict[per2], 10**-1.25)).x
        sub_gr_epistatic_pred = QuadGrowthRates(pair_avals_sub, perturb_dict, hill_params)
        epistatic_high_order_pred_sub = np.concatenate((sub_gr_epistatic_pred[1:, 1:, 1:, 0].flatten(),
                                         sub_gr_epistatic_pred[1:, 0, 1:, 1:].flatten(),
                                        sub_gr_epistatic_pred[1:, 1:, 0, 1:].flatten(),
                                        sub_gr_epistatic_pred[0, 1:, 1:, 1:].flatten(),
                                   sub_gr_epistatic_pred[1:, 1:, 1:, 1:].flatten()))
        epistatic_2nd_order_pred_sub = np.concatenate((sub_gr_epistatic_pred[1:, 1:, 0, 0].flatten(),
                                     sub_gr_epistatic_pred[1:, 0, 1:, 0].flatten(),
                                    sub_gr_epistatic_pred[1:, 0, 0, 1:].flatten(),
                                    sub_gr_epistatic_pred[0, 1:, 1:, 0].flatten(),
                                     sub_gr_epistatic_pred[0, 1:, 0, 1:].flatten(),
                                    sub_gr_epistatic_pred[0, 0, 1:, 1:].flatten()))
        test_2nd_order = np.concatenate(([sub_gr[perturb_pair][1][1:, 1:].flatten() for perturb_pair in sub_gr]))
        sub_rmsd_store[sub]['high'].append(nanrms(epistatic_high_order_pred_sub - true_high_order_gr))
        if sub == 0:  # No subsampling, use the full model
            sub_rmsd_store[sub]['pair'].append(nanrms(epistatic_2nd_order_pred - true_2nd_order_gr))
            break
        else:
            sub_rmsd_store[sub]['pair'].append(nanrms(epistatic_2nd_order_pred_sub - test_2nd_order))
with open(f"{optimization_path}/{date}_sub_opt_rmsd.pickle", 'wb') as handle:
    pickle.dump(sub_rmsd_store, handle, protocol=pickle.HIGHEST_PROTOCOL)
"""
pass

### Export data

In [None]:
# Build Table S7: one row per perturbation combination in rmsd_dict['Epistatic'], listing the
# (up to four) perturbations involved, the two coupling constants (pairs only), and the RMSD of
# the epistatic and null models.
# Rows are collected in a list and the DataFrame is constructed once, so the row count follows
# rmsd_dict (no hard-coded size, no leftover NaN rows) and strings are never written into
# float-typed columns.
table_s7_columns = ['Perturbation 1', 'Perturbation 2', 'Perturbation 3', 'Perturbation 4',
                    'aij', 'aji', 'Model RMSD', 'Null RMSD']
table_s7_rows = []
for perturbation in rmsd_dict['Epistatic']:
    per_list = perturbation.split('-')
    if len(per_list) == 2:
        # Coupling constants are only fit for pairs of perturbations.
        aval_1, aval_2 = pair_avals[perturbation]
    else:
        aval_1, aval_2 = 'N/A', 'N/A'
    # Pad to four perturbation columns; unused slots are labelled 'None'.
    per_list += ['None'] * (4 - len(per_list))
    table_s7_rows.append([*per_list[:4], aval_1, aval_2,
                          rmsd_dict['Epistatic'][perturbation], rmsd_dict['Null'][perturbation]])
table_s7 = pd.DataFrame(table_s7_rows, columns=table_s7_columns)
# mode='a' appends to the existing workbook (the file must already exist);
# if_sheet_exists='replace' overwrites any previous 'Table S7' sheet.
with pd.ExcelWriter('Supplementary_Tables.xlsx', mode='a', if_sheet_exists='replace') as writer:
    table_s7.to_excel(writer, sheet_name='Table S7')