# figure 1B

In [None]:
import pandas as pd, numpy as np, mavenn
import src.landscape_tools as lstoo, src.plot_tools as plottoo
import matplotlib as mpl, matplotlib.pyplot as plt, matplotlib.patches as mpatches
from mpl_toolkits.axes_grid1 import make_axes_locatable

# typeset figure text with LaTeX and a serif font family
plt.rcParams['font.family'] = 'serif'  # serif/main font for text elements
plt.rcParams['text.usetex'] = True     # hand text rendering over to LaTeX

In [None]:
# names of the key mutations, one per variable site (X=V/L/I)
muts = ['G26E', 'F27X', 'T28I', 'S31R', 'S35T', 'V50L', 'S53P', 'S56T', 'T57A', 'Y58F']
# number/name of all sites: the digits between the two residue letters
sites = [int(name[1:-1]) for name in muts]
# number of spin states per site
q = 2
# length of sequence (one position per key mutation)
L = len(muts)

In [None]:
# mapping site number to vector index
pos2i = {pos: i for i, pos in enumerate(sites)}

# load sequence count data
data = pd.read_csv('data/COV107_mutlib_fit_filtered_exp.tsv', sep='\t')

# convert column 'mut' in the data file to spin chains of 0/1s:
# every entry is split on '-'; each piece other than 'WT' is read as a mutation
# whose inner characters (mut[1:-1]) give the site number that is flipped to 1
seqs = []
for genotype in data['mut']:
    seq = [0] * L
    for mut in genotype.split('-'):
        if mut != 'WT':
            seq[pos2i[int(mut[1:-1])]] = 1
    seqs.append(tuple(seq))
data['mut'] = seqs

# drop unused columns, rename the rest, then sum the counts of rows that map
# to the same 0/1 sequence
data = (
    data
    .drop(columns=['mutclass', 'exp1_enrich', 'exp2_enrich'])
    .rename(columns={'mut': 'seq', 'input_Count': 'ni', 'exp1_count': 'no1', 'exp2_count': 'no2'})
)
data = data.groupby('seq').sum().reset_index()

# compute empirical enrichments as log-enrichments (pseudo-count of 1 on every count)
data['F1_emp'] = np.log((1. + data.no1) / (1. + data.ni))
data['F2_emp'] = np.log((1. + data.no2) / (1. + data.ni))

# enumerate all possible sequences and sort dataframe in that enumeration order
seqs = lstoo.seqlist(q=q, L=L)
data['seq'] = pd.Categorical(data['seq'], categories=seqs, ordered=True)
# sort_values returns a new frame, so the result has to be assigned back
data = data.sort_values('seq').reset_index(drop=True)

# subtract offset to have germline at zero fitness
# NOTE(review): row 0 is used as the germline reference, which assumes the all-zero
# sequence is the first entry of lstoo.seqlist and is present in the data -- confirm
data['F1_emp'] -= data.F1_emp.iloc[0]
data['F2_emp'] -= data.F2_emp.iloc[0]

## fit global epistasis model

In [None]:
# make Pandas dataframe with training/test data (here: all data is training data)
np.random.seed(1)
seqs_str = [''.join(str(a) for a in s) for s in data.seq]
# draw one label per row of `data` (not per enumerated sequence), so the column
# lengths agree even when some sequences are absent from the data;
# with p=[1., 0.] every draw is 'training'
data_df = pd.DataFrame({'x': seqs_str, 'y': data.F1_emp,
                        'set': np.random.choice(['training', 'test'], size=len(seqs_str), p=[1., 0.])})

# separate test from data_df (here: all data is training data)
ix_test = data_df['set'] == 'test'
test_df = data_df[ix_test].reset_index(drop=True)
print(f'number of test sequences: {len(test_df):,}')

# remove test data from data_df
data_df = data_df[~ix_test].reset_index(drop=True)
print(f'sequences to be used for training + validation: {len(data_df):,}')

In [None]:
# model definition: boolean alphabet, additive latent phenotype, global-epistasis
# regression with Gaussian noise (heteroskedasticity order 0) and one hidden node
model_kwargs = dict(
    L=L,
    alphabet=['0', '1'],
    gpmap_type='additive',
    regression_type='GE',
    ge_noise_model_type='Gaussian',
    ge_heteroskedasticity_order=0,
    ge_nonlinearity_hidden_nodes=1,
)
model = mavenn.Model(**model_kwargs)

# attach the training data; 20% is held out for validation, rows are shuffled
model.set_data(x=data_df.x, y=data_df.y, validation_frac=.2, shuffle=True)

# fit for a fixed number of epochs with batches of q**L sequences
fit_kwargs = dict(
    learning_rate=.005,
    epochs=1000,
    batch_size=q**L,
    early_stopping=False,
    linear_initialization=False,
)
history = model.fit(**fit_kwargs)

# write the fitted model to disk
model.save('output/1b_repl1')

## plot global epistasis model

In [None]:
# reload the fitted global epistasis model from disk
model = mavenn.load('output/1b_repl1')

# local fields h_i: column 1 minus column 0 of the 'theta_lc' parameters
# (presumably the '1' and '0' characters of the alphabet -- confirm)
theta = model.get_theta()
theta_lc = theta['theta_lc']
h = theta_lc[:, 1] - theta_lc[:, 0]

# latent phenotype (phi) and predicted fitness g(phi) for every sequence in `data`
phis = model.x_to_phi(seqs_str)
yhats = model.x_to_yhat(seqs_str)
data['F1_model'] = yhats

# phi range padded by 0.5 on both sides, sampled on a 1000-point grid,
# and the model fitness evaluated on that grid
phi_lo = min(phis) - .5
phi_hi = max(phis) + .5
phi_lim = [phi_lo, phi_hi]
phi_grid = np.linspace(phi_lo, phi_hi, 1000)
yhat_grid = model.phi_to_yhat(phi_grid)

In [None]:
# write the fitness dataframe (empirical and model columns) to an external file
fitness_csv_path = 'output/1c_fitness_global.csv'
data.to_csv(fitness_csv_path)

In [None]:
# narrow single-column figure for the local fields
fig, ax = plt.subplots(figsize=(1., 3.2))

# diverging colour map with its midpoint moved to the relative position
# |hmin| / (hmax + |hmin|) of the value range
hmin = np.nanmin(h)
hmax = np.nanmax(h)
cmap = plottoo.shiftedColorMap(mpl.cm.bwr_r, midpoint=abs(hmin)/(hmax+abs(hmin)))

# draw the local fields as a one-column image
im = ax.imshow(h.reshape(-1, 1), cmap=cmap)

# dashed horizontal line through the middle of the column
y_mid = (L-1)/2
ax.plot([-.5, .5], [y_mid, y_mid], c='k', linestyle='--')

# ticks: none on x, one per site on y, labelled with the site number in typewriter font
labels = [r'\texttt{%s}'%st[1:-1] for st in muts]
ax.set_xticks([])
ax.set_yticks(range(L))
ax.set_yticklabels(labels)
ax.tick_params(labelsize=15)

# colorbar in its own axes to the right of the image
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.1, 0.7])
cb = fig.colorbar(im, cax=cbar_ax)
cb.ax.tick_params(labelsize=15)
cb.set_ticks([-.5, 0., .5, 1.])

# resize, then save as JPG (300 dpi) and PDF
plottoo.set_size(.625, 2)
for out_path, extra_kwargs in (('output/1b_1.jpg', {'dpi': 300}), ('output/1b_1.pdf', {})):
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.02, **extra_kwargs)
plt.show()

In [None]:
# sequence site of interest, as an index into `sites`/`muts`
# (index 6 is site 53, mutation 'S53P', in the lists defined in this notebook)
i = 6

# create figure
fig, ax = plt.subplots(figsize=(3.2, 2.8))

# colors according to state of sequence site of interest: 'C0' if 0, 'C1' if 1
cs = ['C0' if s[i]==0 else 'C1' for s in data.seq]

# scatter empirical fitness against latent phenotype, and overlay the model curve g(phi)
ax.scatter(phis, data.F1_emp, c=cs, s=5, alpha=.5, zorder=5)
ax.plot(phi_grid, yhat_grid, c='k', lw=2, zorder=5)

# layout
ax.set_xlim(phi_lim)
ax.tick_params(labelsize=15)
# raw strings: '\p' and '\m' are not valid Python escapes and must reach LaTeX unchanged
ax.set_xlabel(r'latent phenotype $\phi(\mathbf{s})$', fontsize=15)
ax.set_ylabel(r'$F_\mathrm{emp}(\mathbf{s})$', fontsize=15)
ax.grid(zorder=5)

# legend: one colour patch per site state plus one for the model curve
handles = [mpatches.Patch(color=c) for c in ['C0', 'C1', 'k']]
ax.legend(handles=handles, labels=[r'wild-type', r'mutated', r'$g(\phi)$'],
          loc='lower right', fontsize=12)

# save plot
plt.savefig('output/1b_2.jpg', bbox_inches='tight', pad_inches=0.02, dpi=300)
plt.savefig('output/1b_2.pdf', bbox_inches='tight', pad_inches=0.02)
plt.show()

In [None]:
# landscape object built from the global-model fitness values; `default` is NaN
lsmodel = lstoo.EmpLS(L=L, q=q, seqs=data.seq, fs=data.F1_model, default=np.nan)


def fitness(seq):
    """Return the fitness that `lsmodel` assigns to `seq`."""
    return lsmodel.fitness(seq)


# epistatic effect matrix (gamma_ij) over all enumerated sequences
gammaijs = lstoo.gammaij(L, seqs, fitness)

In [None]:
# square figure for the pairwise matrix
fig, ax = plt.subplots(figsize=(2.8, 2.8), constrained_layout=True)

# gamma_ij matrix on a symmetric colour scale from -1 to 1
im = ax.imshow(gammaijs, cmap=mpl.cm.bwr_r, vmin=-1., vmax=1.)

# dashed cross through the centre of the matrix
centre = (L-1)/2
extent = [-1., L+1]
ax.plot(extent, [centre, centre], c='k', linestyle='--')
ax.plot([centre, centre], extent, c='k', linestyle='--')

# one tick per site on both axes, labelled with the site number in typewriter font
labels = [r'\texttt{%s}'%st[1:-1] for st in muts]
ax.set_xticks(range(L))
ax.set_xticklabels(labels, rotation='vertical')
ax.set_yticks(range(L))
ax.set_yticklabels(labels)
ax.tick_params(labelsize=15)
ax.set_xlabel(r'$j$', fontsize=15)
ax.set_ylabel(r'$i$', fontsize=15)

# clip the view back to the matrix (the dashed lines overshoot it); y runs top to bottom
ax.set_xlim([-.5, L-.5])
ax.set_ylim([L-.5, -.5])

# colorbar in a thin axes appended on the right
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=.05)
cb = plt.colorbar(im, cax=cax)
cb.ax.tick_params(labelsize=15)
cb.set_ticks([-1, 0, 1])

# resize, then save as JPG (300 dpi) and PDF
plottoo.set_size(2, 2)
for out_path, extra_kwargs in (('output/1b_3.jpg', {'dpi': 300}), ('output/1b_3.pdf', {})):
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.02, **extra_kwargs)
plt.show()