# ESM Model

In [1]:
import sys
import os

# Make the project's `src` directory importable (needed for the `models.esm.*`
# imports further down). The path is relative, so it resolves against the
# kernel's current working directory.
src_dir = os.path.abspath("../src")
# Only add the entry once so re-running this cell does not pile up duplicates
# in sys.path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [2]:
# `!export ...` runs in a temporary subshell, so the variable never reaches the
# kernel process. Set it on this process's environment instead, and do so
# before torch first touches CUDA so the allocator setting can take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [3]:
import gc
import torch
# Start from a clean memory state: run a Python garbage-collection pass, then
# ask PyTorch to release its unused cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [4]:
import pandas as pd

# Load the protein-pair table. The recorded preview below shows the columns
# protein1, protein2, label, sequence1 and sequence2.
df = pd.read_parquet("../datasets/df_string_balanced.parquet")

In [5]:
# Preview the first rows to check the columns used by the embedding step
# (sequence1, sequence2, label).
df.head()

Unnamed: 0,protein1,protein2,label,sequence1,sequence2
0,9606.ENSP00000000233,9606.ENSP00000257770,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MCPRAARAPATLLLALGAVLWPAAGAWELTILHTNDVHSRLEQTSE...
1,9606.ENSP00000000233,9606.ENSP00000226004,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MSGSFELSVQDLNDLLSDGSGCYSLPSQPCNEVTPRIYVGNASVAQ...
2,9606.ENSP00000000233,9606.ENSP00000262455,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MHPAVFLSLPDLRCSLLLLVTWVFTPVTTEITSLDTENIDEILNNA...
3,9606.ENSP00000000233,9606.ENSP00000263265,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MEGSRPRSSLSLASSASTISSLSSLSPKKPTRAVNKIHAFGKRGNA...
4,9606.ENSP00000000233,9606.ENSP00000365686,1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MGMSKSHSFFGYPLSIFFIVVNEFCERFSYYGMRAILILYFTNFIS...


In [6]:
from models.esm.embedder import ESMEmbedder

output_dir = "../processed_data/esm_embeddings"

# Embedding is expensive (the recorded run took close to nine hours), so it is
# skipped whenever the output location already exists.
# NOTE(review): the guard only checks that the path exists, not that the saved
# files inside it are complete -- confirm an interrupted run cannot leave a
# half-filled directory behind.
embeddings_missing = not os.path.exists(output_dir)
if embeddings_missing:
    # Pull the paired sequences and their labels out of the dataframe.
    left_sequences = df["sequence1"].tolist()
    right_sequences = df["sequence2"].tolist()
    pair_labels = df["label"].to_numpy()

    # Embed both sides of every pair, then persist them together with labels.
    esm_embedder = ESMEmbedder(batch_size=32)
    left_embeddings, right_embeddings = esm_embedder.embed_all(left_sequences, right_sequences)
    esm_embedder.save(left_embeddings, right_embeddings, pair_labels, output_dir=output_dir)

    # Drop every large temporary and release cached GPU memory before training.
    del esm_embedder
    del left_sequences, right_sequences
    del left_embeddings, right_embeddings, pair_labels
    gc.collect()
    torch.cuda.empty_cache()
del embeddings_missing

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Encoding with ESM-2: 100%|██████████| 57139/57139 [8:49:12<00:00,  1.80it/s]   


---Saved: ../processed_data/esm_embeddings/esm_protein1.pt
---Saved: ../processed_data/esm_embeddings/esm_protein2.pt
---Saved: ../processed_data/esm_embeddings/esm_labels.pt


In [None]:
from models.esm.train_pipeline import train_model
from models.esm.dataset import ESMDataset
from torch.utils.data import DataLoader, random_split

# NOTE(review): the embeddings were written to "../processed_data/esm_embeddings";
# presumably ESMDataset resolves that sub-directory itself -- confirm.
data_dir = "../processed_data"
batch_size = 64
# Fixed seed so the train/val/test partition is identical on every run.
# Without it, re-running after a kernel restart reshuffles the split and the
# "test" set may contain samples a previously saved model was trained on.
split_seed = 42

# 70 / 15 / 15 split; the test set takes the remainder so the three sizes
# always sum to len(dataset) despite integer truncation.
dataset = ESMDataset(data_dir)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_ds, val_ds, test_ds = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size)

# NOTE(review): the embedding run logged the facebook/esm2_t6_8M_UR50D
# checkpoint; confirm input_dim=1280 matches the feature size ESMDataset yields.
# `patience` presumably controls early stopping inside train_model -- verify.
model = train_model(
    train_loader=train_loader,
    val_loader=val_loader,
    model_save_path="../models",
    input_dim=1280,
    hidden_dim=512,
    lr=1e-3,
    epochs=50,
    patience=3
)

In [None]:
from models.esm.evaluate import evaluate_model

# Score the trained model on the held-out test split. Both `model` and
# `test_loader` are created by the training cell, which must be run first.
evaluate_model(model, test_loader)