# figure S3B

In [None]:
import pandas as pd, numpy as np
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Typeset every text element of the figures (axis labels, tick labels,
# annotations) with LaTeX, using the serif font family.
# NOTE: text.usetex needs a working LaTeX installation on the machine.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['text.usetex'] = True

In [None]:
# positions of the mutated sites, in ascending order
sites = [26, 27, 28, 31, 35, 50, 53, 56, 57, 58]
# mutation labels, one per entry of `sites` (X stands for V, L or I)
muts = [
    'G26E', 'F27X', 'T28I', 'S31R', 'S35T',
    'V50L', 'S53P', 'S56T', 'T57A', 'Y58F',
]
# sequence length (one position per listed site) and number of spin states
L = len(sites)
q = 2

## load fitness models

In [None]:
# Fitness tables read from the output/ directory, in this order:
#   data_specific - specific model, inferred by maximum likelihood
#   data_walsh    - specific model, inferred from the Walsh-Hadamard transform
#   data_global   - global model
data_specific, data_walsh, data_global = (
    pd.read_csv('output/1c_fitness_specific.csv'),
    pd.read_csv('output/1c_fitness_walsh.csv'),
    pd.read_csv('output/1b_fitness_global.csv'),
)

## plot correlations between replicates and between data and model predictions

In [None]:
def histo2D(xs, ys, bins=None):
    """Colour each scatter point by the population of its 2D-histogram bin.

    Parameters
    ----------
    xs, ys : array-like of float, same length
        Point coordinates. pandas Series are accepted; their index is ignored
        and values are taken positionally.
    bins : array-like of float, optional
        Increasing bin edges used for both axes. Defaults to
        ``np.linspace(-4., 2., 50)`` (49 bins spanning [-4, 2]).

    Returns
    -------
    xs2, ys2, z2 : numpy.ndarray
        Coordinates and bin counts, sorted by increasing count so the densest
        points are drawn last (on top) by ``scatter``. Points outside the edges
        get the count of the nearest edge bin. Every count is at least 1, so it
        stays valid under ``LogNorm``.
    """
    edges = np.linspace(-4., 2., 50) if bins is None else np.asarray(bins, dtype=float)
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    counts, _, _ = np.histogram2d(xs, ys, bins=edges)
    last = len(edges) - 2  # index of the last bin
    # Same convention as histogram2d: bins are [lo, hi), except the last one,
    # which also includes its right edge. Out-of-range points are clipped to
    # the nearest bin; the previous argmax lookup silently sent any point above
    # the top edge to bin 0.
    ix = np.clip(np.searchsorted(edges, xs, side='right') - 1, 0, last)
    iy = np.clip(np.searchsorted(edges, ys, side='right') - 1, 0, last)
    # histogram2d does not count out-of-range points, so their bin may be empty.
    # Floor at 1 so LogNorm does not mask them out.
    z = np.maximum(counts[ix, iy], 1)
    order = np.argsort(z, kind='stable')
    return xs[order], ys[order], z[order]

In [None]:
def _density_scatter(axis, xs, ys, *, text_xy=(1, -2.75), fontsize=15, bbox_alpha=1.):
    """Scatter ``ys`` vs ``xs`` on ``axis`` coloured by local density, and annotate R^2.

    Colours are the 2D-histogram bin counts from ``histo2D`` on a log scale.
    The squared Pearson correlation of the two series is written at
    ``text_xy`` (data coordinates) and returned.
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    xs2, ys2, z2 = histo2D(xs, ys)
    axis.scatter(xs2, ys2, c=z2, cmap='jet', marker='.', norm=LogNorm())
    rsq = pearsonr(xs, ys)[0]**2
    label = axis.text(*text_xy, r'$R^2 = %.2f$' % rsq, fontsize=fontsize, ha='center', va='center')
    label.set_bbox(dict(facecolor='white', alpha=bbox_alpha, edgecolor='white'))
    return rsq


def _finish_axis(axis, lims=(-3.5, 2.5)):
    """Draw the y = x reference line and grid, and fix both axis ranges to ``lims``."""
    axis.plot([-100, 100], [-100, 100], c='k')  # long diagonal, clipped by the limits
    axis.set_xlim(lims)
    axis.set_ylim(lims)
    axis.grid()


# create figure: 2x2 grid of density-coloured scatter plots
fig, ax = plt.subplots(figsize=(3.*2, 2.8*2), ncols=2, nrows=2, constrained_layout=True)

# empirical fitness of each replicate, shifted so that the first row is 0
f1_emp = data_specific.F1_emp - data_specific.F1_emp.iloc[0]
f2_emp = data_specific.F2_emp - data_specific.F2_emp.iloc[0]

# (axis, x values, y values, x label, y label)
# NOTE(review): rows of data_global / data_walsh are assumed to list genotypes
# in the same order as data_specific - confirm against the generating notebooks.
panels = [
    # empirical (replicate 1) vs empirical (replicate 2)
    (ax[0, 0], f1_emp, f2_emp, r'data (replicate 1)', r'data (replicate 2)'),
    # specific learning (replicate 1) vs specific learning (replicate 2)
    (ax[0, 1], data_specific.F1_model, data_specific.F2_model,
     r'specific model (replicate 1)', r'specific model (replicate 2)'),
    # empirical (replicate 1) vs global model (replicate 1)
    (ax[1, 0], f1_emp, data_global.F1_model, r'data', r'global model'),
    # empirical (replicate 1) vs specific learning (replicate 1)
    (ax[1, 1], f1_emp, data_specific.F1_model, r'data', r'specific model'),
]
for axis, xs, ys, xlabel, ylabel in panels:
    _density_scatter(axis, xs, ys)
    axis.set_xlabel(xlabel, fontsize=15)
    axis.set_ylabel(ylabel, fontsize=15)
    axis.tick_params(labelsize=15)
    _finish_axis(axis)

# --- inset: specific model inferred by maximum likelihood (replicate 1) vs
# specific model from the Walsh-Hadamard transform (replicate 1); the
# Walsh-Hadamard values are shifted so that the first row is 0
ax2 = ax[1, 1].inset_axes([0.1, 0.575, 0.35, 0.35])
_density_scatter(
    ax2,
    data_specific.F1_model,
    data_walsh.F1_model - data_walsh.F1_model.iloc[0],
    text_xy=(.5, -3.), fontsize=8, bbox_alpha=0.,
)
ax2.set_xlabel(r'learning', fontsize=9, labelpad=-2.5)
ax2.set_ylabel(r'Walsh-Hadamard', fontsize=9, labelpad=-5)
ax2.tick_params(labelbottom=False, labelleft=False)  # keep ticks/grid, hide the numbers
_finish_axis(ax2)

# save plot
fig.savefig('output/s3b_1.jpg', bbox_inches='tight', pad_inches=0.02)
fig.savefig('output/s3b_1.pdf', bbox_inches='tight', pad_inches=0.02)
plt.show()