# ESM Model

In [1]:
import os
import sys

# Resolve the sibling `src` directory against the current working directory
# and put it on the import path so later cells can import from it.
src_dir = os.path.abspath("../src")
sys.path.append(src_dir)

In [2]:
import os

# `!export ...` runs in a short-lived child shell, so the variable it sets
# disappears when that shell exits and never reaches this kernel process.
# Set it in the kernel's own environment instead. This must run before the
# first CUDA allocation for PyTorch's allocator to pick the setting up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [3]:
import gc
import torch
# Start from a clean slate: run Python's garbage collector first so that
# unreachable objects are dropped, then ask PyTorch to release the CUDA
# memory its caching allocator is still holding.
gc.collect()
torch.cuda.empty_cache()

In [4]:
import pandas as pd

# Load the protein-pair table used for embedding and training.
dataset_path = "../datasets/df_string_balanced.parquet"
df = pd.read_parquet(dataset_path)

In [5]:
# Preview the first rows to check the loaded columns
# (protein ids, label, and the two sequences).
df.head()

Unnamed: 0,protein1,protein2,label,sequence1,sequence2
0,9606.ENSP00000000233,9606.ENSP00000257770,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MCPRAARAPATLLLALGAVLWPAAGAWELTILHTNDVHSRLEQTSE...
1,9606.ENSP00000000233,9606.ENSP00000226004,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MSGSFELSVQDLNDLLSDGSGCYSLPSQPCNEVTPRIYVGNASVAQ...
2,9606.ENSP00000000233,9606.ENSP00000262455,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MHPAVFLSLPDLRCSLLLLVTWVFTPVTTEITSLDTENIDEILNNA...
3,9606.ENSP00000000233,9606.ENSP00000263265,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MEGSRPRSSLSLASSASTISSLSSLSPKKPTRAVNKIHAFGKRGNA...
4,9606.ENSP00000000233,9606.ENSP00000365686,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MGMSKSHSFFGYPLSIFFIVVNEFCERFSYYGMRAILILYFTNFIS...


In [6]:
from models.esm.embedder import ESMEmbedder

output_dir = "../processed_data/esm_embeddings"

# Embedding the full dataset is very slow (the recorded run below took close
# to nine hours), so it is only done when no output directory exists yet.
# NOTE(review): the guard only checks that the directory exists, so a
# partially written directory from an interrupted run would also be skipped —
# confirm this is acceptable.
if not os.path.exists(output_dir):
    pair_labels = df["label"].to_numpy()
    first_seqs = df["sequence1"].tolist()
    second_seqs = df["sequence2"].tolist()

    embedder = ESMEmbedder(batch_size=32)
    emb_first, emb_second = embedder.embed_all(first_seqs, second_seqs)
    embedder.save(emb_first, emb_second, pair_labels, output_dir=output_dir)

    # Drop every large temporary before reclaiming host and CUDA memory, so
    # the training cells start without the embedder resident.
    del embedder, first_seqs, second_seqs, emb_first, emb_second, pair_labels
    gc.collect()
    torch.cuda.empty_cache()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Encoding with ESM-2: 100%|██████████| 57139/57139 [8:49:12<00:00,  1.80it/s]   


---Saved: ../processed_data/esm_embeddings/esm_protein1.pt
---Saved: ../processed_data/esm_embeddings/esm_protein2.pt
---Saved: ../processed_data/esm_embeddings/esm_labels.pt


In [4]:
import torch
from torch.utils.data import DataLoader, random_split

from models.esm.train_pipeline import train_model
from models.esm.dataset import ESMDataset

data_dir = "../processed_data/esm_embeddings"
batch_size = 64
# Fixed seed so the train/val/test partition is identical on every run,
# including after a kernel restart. Without it the held-out test set changes
# each time this cell is executed.
SEED = 42

# Seed torch's global RNG; the shuffling train sampler draws its seed from it.
# NOTE(review): weight initialisation inside train_model presumably uses the
# same global RNG — confirm.
torch.manual_seed(SEED)

dataset = ESMDataset(data_dir)

# 70 / 15 / 15 split; the test share takes the remainder so the three sizes
# always sum to len(dataset) despite integer truncation.
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size)

# NOTE(review): input_dim=320 is presumably the embedding width produced by
# the ESM-2 checkpoint used in the embedding step — confirm it matches the
# saved tensors.
model = train_model(
    train_loader=train_loader,
    val_loader=val_loader,
    model_save_path="../models",
    input_dim=320,
    hidden_dim=512,
    lr=1e-3,
    epochs=50,
    patience=6
)


Epoch 1/50


                                                                

Train loss: 0.5232


                                                                

Val loss: 0.4801 | Accuracy: 0.7672
-- Best model saved.

Epoch 2/50


                                                                

Train loss: 0.4733


                                                                

Val loss: 0.4469 | Accuracy: 0.7904
-- Best model saved.

Epoch 3/50


                                                                

Train loss: 0.4461


                                                                

Val loss: 0.4264 | Accuracy: 0.8016
-- Best model saved.

Epoch 4/50


                                                                

Train loss: 0.4271


                                                                

Val loss: 0.4136 | Accuracy: 0.8102
-- Best model saved.

Epoch 5/50


                                                                

Train loss: 0.4107


                                                                

Val loss: 0.4037 | Accuracy: 0.8160
-- Best model saved.

Epoch 6/50


                                                                

Train loss: 0.3983


                                                                

Val loss: 0.3950 | Accuracy: 0.8206
-- Best model saved.

Epoch 7/50


                                                                

Train loss: 0.3874


                                                                

Val loss: 0.3893 | Accuracy: 0.8243
-- Best model saved.

Epoch 8/50


                                                                

Train loss: 0.3778


                                                                

Val loss: 0.3863 | Accuracy: 0.8264
-- Best model saved.

Epoch 9/50


                                                                

Train loss: 0.3688


                                                                

Val loss: 0.3821 | Accuracy: 0.8289
-- Best model saved.

Epoch 10/50


                                                                

Train loss: 0.3608


                                                                

Val loss: 0.3789 | Accuracy: 0.8301
-- Best model saved.

Epoch 11/50


                                                                

Train loss: 0.3542


                                                                

Val loss: 0.3762 | Accuracy: 0.8319
-- Best model saved.

Epoch 12/50


                                                                

Train loss: 0.3480


                                                                

Val loss: 0.3742 | Accuracy: 0.8329
-- Best model saved.

Epoch 13/50


                                                                

Train loss: 0.3423


                                                                

Val loss: 0.3721 | Accuracy: 0.8339
-- Best model saved.

Epoch 14/50


                                                                

Train loss: 0.3365


                                                                

Val loss: 0.3726 | Accuracy: 0.8348

Epoch 15/50


                                                                

Train loss: 0.3312


                                                                

Val loss: 0.3694 | Accuracy: 0.8365
-- Best model saved.

Epoch 16/50


                                                                

Train loss: 0.3275


                                                                

Val loss: 0.3710 | Accuracy: 0.8368

Epoch 17/50


                                                                

Train loss: 0.3224


                                                                

Val loss: 0.3708 | Accuracy: 0.8365

Epoch 18/50


                                                                

Train loss: 0.3184


                                                                

Val loss: 0.3713 | Accuracy: 0.8371

Epoch 19/50


                                                                

Train loss: 0.3149


                                                                

Val loss: 0.3681 | Accuracy: 0.8390
-- Best model saved.

Epoch 20/50


                                                                

Train loss: 0.3105


                                                                

Val loss: 0.3663 | Accuracy: 0.8386
-- Best model saved.

Epoch 21/50


                                                                

Train loss: 0.3069


                                                                

Val loss: 0.3635 | Accuracy: 0.8396
-- Best model saved.

Epoch 22/50


                                                                

Train loss: 0.3040


                                                                

Val loss: 0.3686 | Accuracy: 0.8394

Epoch 23/50


                                                                

Train loss: 0.3010


                                                                

Val loss: 0.3670 | Accuracy: 0.8384

Epoch 24/50


                                                                

Train loss: 0.2978


                                                                

Val loss: 0.3643 | Accuracy: 0.8407

Epoch 25/50


                                                                

Train loss: 0.2950


                                                                

Val loss: 0.3660 | Accuracy: 0.8415

Epoch 26/50


                                                                

Train loss: 0.2917


                                                                

Val loss: 0.3675 | Accuracy: 0.8405

Epoch 27/50


                                                                

Train loss: 0.2902


                                                                

Val loss: 0.3648 | Accuracy: 0.8409
!! Early stopping at epoch 27


In [5]:
from models.esm.evaluate import evaluate_model

# Score the trained model on the held-out test loader.
# NOTE(review): `model` is whatever train_model returned — confirm it carries
# the best-checkpoint weights rather than the final-epoch weights, since
# training stopped early several epochs after the last "Best model saved".
evaluate_model(model, test_loader)

Test Accuracy : 0.8400
Test F1 Score : 0.8369
Test ROC AUC  : 0.9186
