In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

In [2]:
from src.prob_scoring import*
from src.useful_functions import*
from src.data_loader import*
from src.handle_mutations import*
from src.mask_position import*
from src.landscape import*

In [3]:
# read_dirs_paths is handed this notebook's globals() and, judging by its
# printed "Created variables" report, defines one global per entry of the file.
DIR_PATHS_FILE = 'dir_paths.txt'
read_dirs_paths(DIR_PATHS_FILE, globals())

# Later cells use these injected names directly; fail fast with a clear message
# instead of a NameError several cells down if the paths file changes.
_missing_path_vars = {'data_a4_human', 'data_prgym_reference', 'filename1'} - set(globals())
assert not _missing_path_vars, f"{DIR_PATHS_FILE} did not define: {sorted(_missing_path_vars)}"

Created variables:
data_a4_human = ./data/A4_HUMAN_Seuma_2021.csv
data_prgym_reference = ./data/ProteinGym_reference_file_substitutions.csv
filename1 = A4_HUMAN_Seuma_2021.csv
filename2 = ProteinGym_reference_file_substitutions.csv


In [4]:
# Load the A4_HUMAN (Seuma 2021) deep-mutational-scan table.
df = load_dms(data_a4_human)
# Every downstream cell indexes these columns; catch a schema change here.
assert {"mutant", "fitness"}.issubset(df.columns), f"unexpected columns: {list(df.columns)}"
df.head()

Unnamed: 0,mutant,fitness,fitness_bin
0,D672V:V710G,-4.70038,0
1,D672G:I712L,-1.531949,1
2,D672G:I712K,-2.673194,0
3,D672G:I703V,-1.621747,1
4,D672G:I703T,-3.131144,0


In [5]:
# Wild-type sequence for this assay (filename1), presumably looked up by
# filename in the ProteinGym substitutions reference table.
wt_seq = get_wt_sequence(data_prgym_reference, filename1)

In [6]:
# Split the assay into single and double mutants. The single-mutant preview
# keeps non-contiguous index labels from `df` (124, 125, 126, 208, ...), i.e.
# these are row subsets of `df`. Take explicit copies so the later column
# assignments ("sequence", "PLL") write to independent frames rather than a
# possible view (which triggers pandas' SettingWithCopyWarning).
df_single = get_single_mutants(df).copy()
df_double = get_double_mutants(df).copy()

In [7]:
# Build each variant's full protein sequence by applying its mutation string
# (e.g. "D672V:V710G") to the wild-type sequence.
df["sequence"] = df["mutant"].map(lambda mutation: apply_mutations(wt_seq, mutation))
df.head()

Unnamed: 0,mutant,fitness,fitness_bin,sequence
0,D672V:V710G,-4.70038,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
1,D672G:I712L,-1.531949,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
2,D672G:I712K,-2.673194,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
3,D672G:I703V,-1.621747,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
4,D672G:I703T,-3.131144,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...


In [8]:
# Mutated sequence for every single-mutant row.
df_single["sequence"] = df_single["mutant"].map(
    lambda single_mutation: apply_mutations(wt_seq, single_mutation)
)
df_single.head()

Unnamed: 0,mutant,fitness,fitness_bin,sequence
124,D672N,0.3525,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
125,D672K,-0.117352,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
126,D672I,-2.40434,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
208,D672H,0.133308,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
312,D672E,0.017479,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...


In [9]:
# Mutated sequence for every double-mutant row ("A123B:C456D" strings).
df_double["sequence"] = df_double["mutant"].apply(
    lambda mutation_pair: apply_mutations(wt_seq, mutation_pair)
)
df_double.head()

Unnamed: 0,mutant,fitness,fitness_bin,sequence
0,D672V:V710G,-4.70038,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
1,D672G:I712L,-1.531949,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
2,D672G:I712K,-2.673194,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
3,D672G:I703V,-1.621747,1,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...
4,D672G:I703T,-3.131144,0,MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMN...


In [10]:
# ESM-2 (33 layers, 650M parameters) checkpoint used below for PLL scoring.
model_id = "facebook/esm2_t33_650M_UR50D"
# Prefer the GPU but fall back to CPU, so the notebook still runs (slowly) on
# machines without CUDA instead of failing with a hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = SeqTokenizer(model_id, device=device)
model = tokenizer.model
mask_id = tokenizer.mask_id

In [11]:
# Optional: score every variant in `df`, not only the single mutants.
# Disabled by default (presumably because PLL scoring of the full set is slow);
# set the flag to True to enable it.
SCORE_ALL_VARIANTS = False
if SCORE_ALL_VARIANTS:
    df["PLL"] = batch_pll(df["sequence"].tolist(), tokenizer, model)

In [None]:
# Pseudo-log-likelihood of each single-mutant sequence under ESM-2.
df_single["PLL"] = batch_pll(list(df_single["sequence"]), tokenizer, model)

In [None]:
# Optional: also score the double mutants. Disabled by default; set the flag
# to True to compute PLL for `df_double` as well.
SCORE_DOUBLE_MUTANTS = False
if SCORE_DOUBLE_MUTANTS:
    df_double["PLL"] = batch_pll(df_double["sequence"].tolist(), tokenizer, model)

In [None]:
import scipy.stats
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Rows usable for the correlation and the trend fit: both scores present and
# a non-infinite PLL (±inf would break the least-squares fit).
df_clean = (
    df_single
    .dropna(subset=['fitness', 'PLL'])
    .pipe(lambda frame: frame[~np.isinf(frame['PLL'])])
)

In [None]:
# Rank correlation between measured DMS fitness and the ESM-2 PLL score.
corr, pval = scipy.stats.spearmanr(
    df_clean['fitness'],
    df_clean['PLL'],
)
print(f"Spearman Correlation: {corr:.4f}", f"P-value: {pval:.4e}", sep="\n")

# Build the whole figure in ONE cell with an explicit Figure/Axes. With the
# inline backend Jupyter closes pyplot figures at the end of every cell, so
# spreading plt.* calls over several cells put the trend line, labels and legend
# on fresh figures and saved an empty PNG.
fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(df_clean['fitness'], df_clean['PLL'], alpha=0.5, s=10, c='blue', label='Mutants')

# Least-squares line of PLL on fitness, drawn across the observed fitness range;
# the legend reports the Spearman rho computed above.
trend = np.poly1d(np.polyfit(df_clean['fitness'], df_clean['PLL'], 1))
x_range = np.linspace(df_clean['fitness'].min(), df_clean['fitness'].max(), 100)
ax.plot(x_range, trend(x_range), "r--", linewidth=2, label=f'Trend (ρ={corr:.3f})')

ax.set_xlabel('Real Lab Fitness Score (DMS)', fontsize=12)
ax.set_ylabel('ESM-2 Pseudo-Log-Likelihood (PLL)', fontsize=12)
ax.set_title(f'Fitness vs. PLL Score\nSpearman Correlation = {corr:.3f}', fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend()

fig.tight_layout()
fig.savefig("pll_correlation_plot.png", dpi=300)
plt.show()