In [None]:
import pandas as pd, numpy as np, random as rd
from scipy.optimize import minimize
from scipy.stats import pearsonr
from scipy.special import binom
from sympy import fwht, ifwht
import src.landscape_tools as lstoo
import matplotlib as mpl, matplotlib.pyplot as plt

# Typeset all figure text through LaTeX with a serif font, so that labels such as r'$R^2$' render as math.
plt.rcParams['font.family'] = 'serif'  # serif/main font for text elements
plt.rcParams['text.usetex'] = True     # LaTeX rendering, also used for the tick labels

In [None]:
# sequence length L (number of variable sites) and number of spin states q per site
L = 10
q = 2
# residue numbers of the L variable sites; list position = vector index of the site
sites = [
    26, 27, 28, 31, 35,
    50, 53, 56, 57, 58,
]
# key mutation at each site, in the same order as `sites` (X = V/L/I)
muts = [
    'G26E', 'F27X', 'T28I', 'S31R', 'S35T',
    'V50L', 'S53P', 'S56T', 'T57A', 'Y58F',
]

In [None]:
# mapping site number to vector index
pos2i = {pos: i for i, pos in enumerate(sites)}


def _mut2seq(mutstr):
    """Convert a '-'-joined mutation string from the raw data into a 0/1 spin chain.

    The token 'WT' contributes no mutation. Any other token such as 'G26E' sets to 1 the spin
    of the site whose number lies between its first and last character (here 26).
    Returns a tuple of length L.
    """
    seq = [0] * L
    for mut in mutstr.split('-'):
        if mut != 'WT':
            seq[pos2i[int(mut[1:-1])]] = 1
    return tuple(seq)


# load sequence count data
data_raw = pd.read_csv('data/COV107_mutlib_fit_filtered_exp.tsv', sep='\t')

# convert column 'mut' to spin chains, drop unused columns, rename, and sum counts of identical sequences
data = (
    data_raw
    .assign(mut=[_mut2seq(m) for m in data_raw['mut']])
    .drop(columns=['mutclass', 'exp1_enrich', 'exp2_enrich'])
    .rename(columns={'mut': 'seq', 'input_Count': 'ni', 'exp1_count': 'no1', 'exp2_count': 'no2'})
    .groupby('seq').sum()
    .reset_index()
)

# empirical log-enrichments, with a pseudocount of 1 to avoid log(0) for unobserved sequences
data['F1_emp'] = np.log((1. + data.no1) / (1. + data.ni))
data['F2_emp'] = np.log((1. + data.no2) / (1. + data.ni))

# enumerate all possible sequences and order the rows accordingly
seqs = lstoo.seqlist(q=q, L=L)
data['seq'] = pd.Categorical(data['seq'], categories=seqs, ordered=True)
# sort_values returns a new frame: it must be assigned back, otherwise the rows stay unsorted
data = data.sort_values('seq').reset_index(drop=True)
# downstream cells (batch assignment, full design matrices) assume one row per genotype
assert data['seq'].notna().all(), 'found sequences missing from lstoo.seqlist'
assert len(data) == q**L, f'expected {q**L} genotypes, found {len(data)}'

# subtract offset to have germline at zero fitness; the germline is taken to be the
# all-zero spin chain (the 'WT' entry of the raw file), located explicitly rather than by row position
is_germline = np.array([sum(s) == 0 for s in data['seq']])
assert is_germline.sum() == 1, 'expected exactly one germline sequence'
data['F1_emp'] -= data.loc[is_germline, 'F1_emp'].iloc[0]
data['F2_emp'] -= data.loc[is_germline, 'F2_emp'].iloc[0]

## cross-validation for specific epistasis models using maximum likelihood inference

In [None]:
# Split the q**L genotypes (rows of `data`) at random into cross-validation batches labelled 1..nbatches.
# The fixed seed makes the batch labels identical on every run.
rd.seed(1)
nbatches = 10
n_genotypes = q**L
batch_labels = []
for _genotype in range(n_genotypes):
    batch_labels.append(rd.randint(1, nbatches))  # randint is inclusive on both ends
data['batch'] = batch_labels

In [None]:
# replicas
rs = [1, 2]

# maximum order of sequence site interactions in the specific epistasis model
order_max = 5


def loglike(X, ni, no, design=None):
    """Negative log-likelihood of output counts under the specific epistasis model, divided by q**L.

    Fitness values are F = design.dot(X). The predicted output frequency of each genotype is
    proportional to ni * exp(F), giving (up to F-independent constants) a multinomial likelihood:
        -( sum(no*F) - sum(no) * log(sum(ni*exp(F))) ) / q**L

    Parameters
    ----------
    X : model parameters (length = number of columns of `design`).
    ni, no : input / output read counts, in the same row order as `design`.
    design : matrix mapping parameters to fitness values. If omitted, falls back to the
        global `M` for backward compatibility with callers that do not pass it.
    """
    if design is None:
        design = M
    F = design.dot(X)
    return -( (no*F).sum() - no.sum()*np.log((ni*np.exp(F)).sum()) ) / q**L


# iterate over experimental replicates
for r in rs:

    # inferred model parameters for all orders and batches, keyed by (order, test batch)
    ress = {}

    # iterate over max orders of site interactions (same range as used in the evaluation below)
    for order in range(1, order_max+1):
        # iterate over test batches (to be removed for inference step)
        for b in range(1, nbatches+1):

            # remove test batch from sequence count data
            data_batch = data[data.batch != b]
            print(f' doing replicate {r} order {order} batch {b} no. training sequences: {len(data_batch)}')

            # matrix in F=M.X where F = vector of fitness values, X = vector of model parameters,
            # restricted to training set genotypes
            M = lstoo.mkM(q=q, L=L, order=order, seqs=data_batch.seq)

            # plain arrays avoid pandas overhead in the hundreds of objective evaluations
            ni = data_batch.ni.to_numpy(dtype=float)
            no = data_batch[f'no{r}'].to_numpy(dtype=float)

            # model inference using maximum likelihood; the design matrix is passed explicitly
            res = minimize(loglike,
                           x0=np.zeros(M.shape[1]),
                           args=(ni, no, M),
                           method='BFGS', tol=1e-3)
            if not res.success:
                # keep the result, but make non-convergence visible instead of silent
                print(f'   warning: optimizer did not converge ({res.message})')

            # collect model parameters
            ress[(order, b)] = res.x

    # save model parameters to external file (a pickled dict; reload with allow_pickle=True)
    np.save(f'output/1c_repl{r}_crossvalidation.npy', ress)

## plot cross-validation data

In [None]:
# collector list for cross-validation results: rows of [replicate, order, R^2 on test batch]
rsqs = []

# batch label of each genotype, as a plain array for positional masking
batches = data.batch.to_numpy()

# iterate over experimental replicates
for r in rs:

    # load cross-validation data (dict keyed by (order, test batch), written by the fitting cell)
    ress = np.load(f'output/1c_repl{r}_crossvalidation.npy', allow_pickle=True).item()

    # models of replicate r were fit to counts no{r}, so they are scored against F{r}_emp
    F_emp = data[f'F{r}_emp'].to_numpy()

    # iterate over model orders
    for order in range(1, order_max+1):

        # matrix for F = M.X over all genotypes; assumes its rows follow the row order of `data`
        # (both ordered by lstoo.seqlist) — TODO confirm against lstoo.mkM
        M = lstoo.mkM(q=q, L=L, order=order)

        # iterate over test batches:
        for b in range(1, nbatches+1):

            # compute fitness values of the model fitted to training batches
            F_model = M.dot(ress[(order, b)])
            # explained variance (squared Pearson correlation) on the held-out test batch
            is_test = batches == b
            rsq = pearsonr(F_model[is_test], F_emp[is_test])[0]**2
            # collect Rsquare values
            rsqs.append([r, order, rsq])

# create dataframe of Rsquare values by replicate experiment and model order
data_crossvalid = pd.DataFrame(rsqs, columns=['replicate', 'order', 'Rsq'])

In [None]:
# Summarize the test-batch R^2 values per (replicate, order) as mean and standard deviation.
# The result has two-level columns: ('replicate',''), ('order',''), ('Rsq','mean'), ('Rsq','std').
data_crossvalid = (
    data_crossvalid
    .groupby(['replicate', 'order'])
    .agg({'Rsq': ['mean', 'std']})
    .reset_index()
)

In [None]:
# create figure
fig, ax = plt.subplots(figsize=(3.2, 2.8))

# plot mean +- std of the test-batch R^2 versus interaction order, one line style per replicate
for r, ls in zip(rs, ['dashed', 'dotted']):
    to_plot = data_crossvalid[data_crossvalid.replicate == r].sort_values('order')
    ax.errorbar(to_plot.order, to_plot.Rsq['mean'], yerr=to_plot.Rsq['std'],
                c='k', ls=ls, label=f'replicate {r}')

# layout
# ticks follow the orders actually fitted instead of a hard-coded maximum
ax.set_xticks(range(1, order_max+1))
ax.tick_params(labelsize=15)
ax.set_xlabel(r'interaction order $p$', fontsize=15)
ax.set_ylabel(r'$R^2$ on test data', fontsize=15)
# labels live on the errorbar artists, so the legend cannot drift out of sync with the plotted replicates
ax.legend(fontsize=15, loc='lower center')
ax.grid(zorder=-1)

# save plot
fig.savefig('output/s3a_1.jpg', bbox_inches='tight', pad_inches=0.02, dpi=300)
fig.savefig('output/s3a_1.pdf', bbox_inches='tight', pad_inches=0.02)
plt.show()

## bandpass filter for specific epistasis models using Walsh-Hadamard transform

In [None]:
# Reload the empirical fitness values and the maximum-likelihood specific-epistasis model
# fitness values (column 'F1_model') from the external file.
data = pd.read_csv('output/1c_fitness_specific.csv')


def _str2seq(text):
    """Parse the CSV string form of a sequence tuple, e.g. '(0, 1, 0)', back into a tuple of ints."""
    return tuple(int(a) for a in text[1:-1].split(', '))


data['seq'] = data.seq.map(_str2seq)

In [None]:
def _walsh_spectrum(F, seqs):
    """Normalized Walsh-Hadamard coefficients of a fitness vector and its Hadamard spectrum.

    Returns (fhat, beta) where fhat[k] = fwht(F)[k] / q**L as Python floats, and
    beta[n] (n = 0..L) is the mean of fhat[k]**2 over the coefficients whose sequence seqs[k]
    has n mutated sites. Pairing coefficient k with seqs[k] assumes the rows are in natural
    binary order, matching the index convention of fwht — TODO confirm against lstoo.seqlist.
    """
    # cast sympy numbers to float so downstream numpy/matplotlib see plain floats, not objects
    fhat = [float(c) / q**L for c in fwht(list(F))]
    n_muts = [sum(s) for s in seqs]
    beta = [np.mean([f**2 for m, f in zip(n_muts, fhat) if m == n]) for n in range(L+1)]
    return fhat, beta


# Walsh-Hadamard transform and Hadamard spectrum of the empirical fitness values
fhats, betas = _walsh_spectrum(data.F1_emp, data.seq)
# same for the specific epistasis model inferred from maximum likelihood
fhats_model, betas_model = _walsh_spectrum(data.F1_model, data.seq)

In [None]:
# Figure: Hadamard amplitude beta_p versus interaction order p, empirical (solid) vs model (dashed).
fig, ax = plt.subplots(figsize=[3.2, 2.8])

p_axis = range(L+1)
ax.plot(p_axis, betas, c='k', label=r'$F_\mathrm{emp}$')
ax.plot(p_axis, betas_model, c='k', ls='dashed', label=r'$F_\mathrm{model}$')

# log-scaled y axis with one tick per decade between 1e-5 and 1
ax.set_yscale('log')
ax.set_ylim([1e-5, 1e0])
ax.set_xticks(p_axis)
ax.set_yticks([10**(-k) for k in range(5+1)])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'interaction order $p$', fontsize=15)
ax.set_ylabel(r'Hadamard amplitude $\beta_p$', fontsize=15)
ax.grid(axis='y')
ax.legend(fontsize=15, loc='upper right')

# write raster and vector versions, then display
fig.savefig('output/s3a_2.jpg', bbox_inches='tight', pad_inches=0.02, dpi=300)
fig.savefig('output/s3a_2.pdf', bbox_inches='tight', pad_inches=0.02)
plt.show()

## "fit" specific epistasis model using Walsh-Hadamard transform

In [None]:
# set all Hadamard coefficients beyond maximum interaction order to zero
order = 3
# interaction order of each coefficient = number of mutated sites of the sequence sharing its index.
# This is the same convention as the Hadamard spectrum above, rather than assuming the coefficients
# are sorted by order (fwht returns them in natural bit-pattern order, not grouped by order).
coef_orders = np.array([sum(s) for s in data.seq])
fhats_trunc = np.asarray(fhats, dtype=float)
fhats_trunc[coef_orders > order] = 0.

# obtain fitness values from inverse Walsh-Hadamard transform on the truncated Hadamard spectrum;
# ifwht divides by q**L, which undoes the 1/q**L normalization of fhats only together with the factor q**L
data_walsh = data.copy(deep=True)[['ni', 'no1', 'no2', 'F1_emp', 'F2_emp']]
data_walsh['F1_model'] = [float(q**L*f) for f in ifwht(fhats_trunc)]

In [None]:
# Write the Walsh-truncated fitness table to disk, without the pandas index column.
data_walsh.to_csv(path_or_buf='output/1c_fitness_walsh.csv', index=False)